You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

validators.py 46 kB

5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Built-in validators.
  17. """
  18. import inspect as ins
  19. import os
  20. import re
  21. from functools import wraps
  22. import numpy as np
  23. from mindspore._c_expression import typing
  24. from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
  25. INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
  26. validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
  27. check_columns, check_pos_int32, check_valid_str
  28. from . import datasets
  29. from . import samplers
  30. def check_imagefolderdataset(method):
  31. """A wrapper that wraps a parameter checker around the original Dataset(ImageFolderDataset)."""
  32. @wraps(method)
  33. def new_method(self, *args, **kwargs):
  34. _, param_dict = parse_user_args(method, *args, **kwargs)
  35. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  36. nreq_param_bool = ['shuffle', 'decode']
  37. nreq_param_list = ['extensions']
  38. nreq_param_dict = ['class_indexing']
  39. dataset_dir = param_dict.get('dataset_dir')
  40. check_dir(dataset_dir)
  41. validate_dataset_param_value(nreq_param_int, param_dict, int)
  42. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  43. validate_dataset_param_value(nreq_param_list, param_dict, list)
  44. validate_dataset_param_value(nreq_param_dict, param_dict, dict)
  45. check_sampler_shuffle_shard_options(param_dict)
  46. cache = param_dict.get('cache')
  47. check_cache_option(cache)
  48. return method(self, *args, **kwargs)
  49. return new_method
  50. def check_mnist_cifar_dataset(method):
  51. """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset, Cifar10/100Dataset)."""
  52. @wraps(method)
  53. def new_method(self, *args, **kwargs):
  54. _, param_dict = parse_user_args(method, *args, **kwargs)
  55. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  56. nreq_param_bool = ['shuffle']
  57. dataset_dir = param_dict.get('dataset_dir')
  58. check_dir(dataset_dir)
  59. usage = param_dict.get('usage')
  60. if usage is not None:
  61. check_valid_str(usage, ["train", "test", "all"], "usage")
  62. validate_dataset_param_value(nreq_param_int, param_dict, int)
  63. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  64. check_sampler_shuffle_shard_options(param_dict)
  65. cache = param_dict.get('cache')
  66. check_cache_option(cache)
  67. return method(self, *args, **kwargs)
  68. return new_method
  69. def check_manifestdataset(method):
  70. """A wrapper that wraps a parameter checker around the original Dataset(ManifestDataset)."""
  71. @wraps(method)
  72. def new_method(self, *args, **kwargs):
  73. _, param_dict = parse_user_args(method, *args, **kwargs)
  74. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  75. nreq_param_bool = ['shuffle', 'decode']
  76. nreq_param_str = ['usage']
  77. nreq_param_dict = ['class_indexing']
  78. dataset_file = param_dict.get('dataset_file')
  79. check_file(dataset_file)
  80. validate_dataset_param_value(nreq_param_int, param_dict, int)
  81. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  82. validate_dataset_param_value(nreq_param_str, param_dict, str)
  83. validate_dataset_param_value(nreq_param_dict, param_dict, dict)
  84. check_sampler_shuffle_shard_options(param_dict)
  85. cache = param_dict.get('cache')
  86. check_cache_option(cache)
  87. return method(self, *args, **kwargs)
  88. return new_method
  89. def check_tfrecorddataset(method):
  90. """A wrapper that wraps a parameter checker around the original Dataset(TFRecordDataset)."""
  91. @wraps(method)
  92. def new_method(self, *args, **kwargs):
  93. _, param_dict = parse_user_args(method, *args, **kwargs)
  94. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  95. nreq_param_list = ['columns_list']
  96. nreq_param_bool = ['shard_equal_rows']
  97. dataset_files = param_dict.get('dataset_files')
  98. if not isinstance(dataset_files, (str, list)):
  99. raise TypeError("dataset_files should be of type str or a list of strings.")
  100. validate_dataset_param_value(nreq_param_int, param_dict, int)
  101. validate_dataset_param_value(nreq_param_list, param_dict, list)
  102. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  103. check_sampler_shuffle_shard_options(param_dict)
  104. cache = param_dict.get('cache')
  105. check_cache_option(cache)
  106. return method(self, *args, **kwargs)
  107. return new_method
  108. def check_vocdataset(method):
  109. """A wrapper that wraps a parameter checker around the original Dataset(VOCDataset)."""
  110. @wraps(method)
  111. def new_method(self, *args, **kwargs):
  112. _, param_dict = parse_user_args(method, *args, **kwargs)
  113. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  114. nreq_param_bool = ['shuffle', 'decode']
  115. nreq_param_dict = ['class_indexing']
  116. dataset_dir = param_dict.get('dataset_dir')
  117. check_dir(dataset_dir)
  118. task = param_dict.get('task')
  119. type_check(task, (str,), "task")
  120. usage = param_dict.get('usage')
  121. type_check(usage, (str,), "usage")
  122. if task == "Segmentation":
  123. imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", usage + ".txt")
  124. if param_dict.get('class_indexing') is not None:
  125. raise ValueError("class_indexing is invalid in Segmentation task")
  126. elif task == "Detection":
  127. imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", usage + ".txt")
  128. else:
  129. raise ValueError("Invalid task : " + task)
  130. check_file(imagesets_file)
  131. validate_dataset_param_value(nreq_param_int, param_dict, int)
  132. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  133. validate_dataset_param_value(nreq_param_dict, param_dict, dict)
  134. check_sampler_shuffle_shard_options(param_dict)
  135. cache = param_dict.get('cache')
  136. check_cache_option(cache)
  137. return method(self, *args, **kwargs)
  138. return new_method
  139. def check_cocodataset(method):
  140. """A wrapper that wraps a parameter checker around the original Dataset(CocoDataset)."""
  141. @wraps(method)
  142. def new_method(self, *args, **kwargs):
  143. _, param_dict = parse_user_args(method, *args, **kwargs)
  144. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  145. nreq_param_bool = ['shuffle', 'decode']
  146. dataset_dir = param_dict.get('dataset_dir')
  147. check_dir(dataset_dir)
  148. annotation_file = param_dict.get('annotation_file')
  149. check_file(annotation_file)
  150. task = param_dict.get('task')
  151. type_check(task, (str,), "task")
  152. if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint'}:
  153. raise ValueError("Invalid task type")
  154. validate_dataset_param_value(nreq_param_int, param_dict, int)
  155. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  156. sampler = param_dict.get('sampler')
  157. if sampler is not None and isinstance(sampler, samplers.PKSampler):
  158. raise ValueError("CocoDataset doesn't support PKSampler")
  159. check_sampler_shuffle_shard_options(param_dict)
  160. cache = param_dict.get('cache')
  161. check_cache_option(cache)
  162. return method(self, *args, **kwargs)
  163. return new_method
  164. def check_celebadataset(method):
  165. """A wrapper that wraps a parameter checker around the original Dataset(CelebADataset)."""
  166. @wraps(method)
  167. def new_method(self, *args, **kwargs):
  168. _, param_dict = parse_user_args(method, *args, **kwargs)
  169. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  170. nreq_param_bool = ['shuffle', 'decode']
  171. nreq_param_list = ['extensions']
  172. nreq_param_str = ['dataset_type']
  173. dataset_dir = param_dict.get('dataset_dir')
  174. check_dir(dataset_dir)
  175. validate_dataset_param_value(nreq_param_int, param_dict, int)
  176. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  177. validate_dataset_param_value(nreq_param_list, param_dict, list)
  178. validate_dataset_param_value(nreq_param_str, param_dict, str)
  179. usage = param_dict.get('usage')
  180. if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
  181. raise ValueError("usage should be one of 'all', 'train', 'valid' or 'test'.")
  182. check_sampler_shuffle_shard_options(param_dict)
  183. sampler = param_dict.get('sampler')
  184. if sampler is not None and isinstance(sampler, samplers.PKSampler):
  185. raise ValueError("CelebADataset does not support PKSampler.")
  186. cache = param_dict.get('cache')
  187. check_cache_option(cache)
  188. return method(self, *args, **kwargs)
  189. return new_method
  190. def check_save(method):
  191. """A wrapper that wraps a parameter checker around the saved operator."""
  192. @wraps(method)
  193. def new_method(self, *args, **kwargs):
  194. _, param_dict = parse_user_args(method, *args, **kwargs)
  195. nreq_param_int = ['num_files']
  196. nreq_param_str = ['file_name', 'file_type']
  197. validate_dataset_param_value(nreq_param_int, param_dict, int)
  198. if (param_dict.get('num_files') <= 0 or param_dict.get('num_files') > 1000):
  199. raise ValueError("num_files should between {} and {}.".format(1, 1000))
  200. validate_dataset_param_value(nreq_param_str, param_dict, str)
  201. if param_dict.get('file_type') != 'mindrecord':
  202. raise ValueError("{} dataset format is not supported.".format(param_dict.get('file_type')))
  203. return method(self, *args, **kwargs)
  204. return new_method
  205. def check_iterator(method):
  206. """A wrapper that wraps a parameter checker around the original create_tuple_iterator and create_dict_iterator."""
  207. @wraps(method)
  208. def new_method(self, *args, **kwargs):
  209. _, param_dict = parse_user_args(method, *args, **kwargs)
  210. nreq_param_bool = ['output_numpy']
  211. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  212. return method(self, *args, **kwargs)
  213. return new_method
  214. def check_minddataset(method):
  215. """A wrapper that wraps a parameter checker around the original Dataset(MindDataset)."""
  216. @wraps(method)
  217. def new_method(self, *args, **kwargs):
  218. _, param_dict = parse_user_args(method, *args, **kwargs)
  219. nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded']
  220. nreq_param_list = ['columns_list']
  221. nreq_param_dict = ['padded_sample']
  222. dataset_file = param_dict.get('dataset_file')
  223. if isinstance(dataset_file, list):
  224. if len(dataset_file) > 4096:
  225. raise ValueError("length of dataset_file should less than or equal to {}.".format(4096))
  226. for f in dataset_file:
  227. check_file(f)
  228. else:
  229. check_file(dataset_file)
  230. validate_dataset_param_value(nreq_param_int, param_dict, int)
  231. validate_dataset_param_value(nreq_param_list, param_dict, list)
  232. validate_dataset_param_value(nreq_param_dict, param_dict, dict)
  233. check_sampler_shuffle_shard_options(param_dict)
  234. check_padding_options(param_dict)
  235. return method(self, *args, **kwargs)
  236. return new_method
  237. def check_generatordataset(method):
  238. """A wrapper that wraps a parameter checker around the original Dataset(GeneratorDataset)."""
  239. @wraps(method)
  240. def new_method(self, *args, **kwargs):
  241. _, param_dict = parse_user_args(method, *args, **kwargs)
  242. source = param_dict.get('source')
  243. if not callable(source):
  244. try:
  245. iter(source)
  246. except TypeError:
  247. raise TypeError("source should be callable, iterable or random accessible")
  248. column_names = param_dict.get('column_names')
  249. if column_names is not None:
  250. check_columns(column_names, "column_names")
  251. schema = param_dict.get('schema')
  252. if column_names is None and schema is None:
  253. raise ValueError("Neither columns_names not schema are provided.")
  254. if schema is not None:
  255. if not isinstance(schema, datasets.Schema) and not isinstance(schema, str):
  256. raise ValueError("schema should be a path to schema file or a schema object.")
  257. # check optional argument
  258. nreq_param_int = ["num_samples", "num_parallel_workers", "num_shards", "shard_id"]
  259. validate_dataset_param_value(nreq_param_int, param_dict, int)
  260. nreq_param_list = ["column_types"]
  261. validate_dataset_param_value(nreq_param_list, param_dict, list)
  262. nreq_param_bool = ["shuffle"]
  263. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  264. num_shards = param_dict.get("num_shards")
  265. shard_id = param_dict.get("shard_id")
  266. if (num_shards is None) != (shard_id is None):
  267. # These two parameters appear together.
  268. raise ValueError("num_shards and shard_id need to be passed in together")
  269. if num_shards is not None:
  270. check_pos_int32(num_shards, "num_shards")
  271. if shard_id >= num_shards:
  272. raise ValueError("shard_id should be less than num_shards.")
  273. sampler = param_dict.get("sampler")
  274. if sampler is not None:
  275. if isinstance(sampler, samplers.PKSampler):
  276. raise ValueError("PKSampler is not supported by GeneratorDataset")
  277. if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler,
  278. samplers.RandomSampler, samplers.SubsetRandomSampler,
  279. samplers.WeightedRandomSampler, samplers.Sampler)):
  280. try:
  281. iter(sampler)
  282. except TypeError:
  283. raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers")
  284. if sampler is not None and not hasattr(source, "__getitem__"):
  285. raise ValueError("sampler is not supported if source does not have attribute '__getitem__'")
  286. if num_shards is not None and not hasattr(source, "__getitem__"):
  287. raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'")
  288. return method(self, *args, **kwargs)
  289. return new_method
  290. def check_random_dataset(method):
  291. """A wrapper that wraps a parameter checker around the original Dataset(RandomDataset)."""
  292. @wraps(method)
  293. def new_method(self, *args, **kwargs):
  294. _, param_dict = parse_user_args(method, *args, **kwargs)
  295. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows']
  296. nreq_param_bool = ['shuffle']
  297. nreq_param_list = ['columns_list']
  298. validate_dataset_param_value(nreq_param_int, param_dict, int)
  299. validate_dataset_param_value(nreq_param_bool, param_dict, bool)
  300. validate_dataset_param_value(nreq_param_list, param_dict, list)
  301. check_sampler_shuffle_shard_options(param_dict)
  302. cache = param_dict.get('cache')
  303. check_cache_option(cache)
  304. return method(self, *args, **kwargs)
  305. return new_method
  306. def check_pad_info(key, val):
  307. """check the key and value pair of pad_info in batch"""
  308. type_check(key, (str,), "key in pad_info")
  309. if val is not None:
  310. assert len(val) == 2, "value of pad_info should be a tuple of size 2"
  311. type_check(val, (tuple,), "value in pad_info")
  312. if val[0] is not None:
  313. type_check(val[0], (list,), "pad_shape")
  314. for dim in val[0]:
  315. if dim is not None:
  316. type_check(dim, (int,), "dim in pad_shape")
  317. assert dim > 0, "pad shape should be positive integers"
  318. if val[1] is not None:
  319. type_check(val[1], (int, float, str, bytes), "pad_value")
  320. def check_bucket_batch_by_length(method):
  321. """check the input arguments of bucket_batch_by_length."""
  322. @wraps(method)
  323. def new_method(self, *args, **kwargs):
  324. [column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, pad_info,
  325. pad_to_bucket_boundary, drop_remainder], _ = parse_user_args(method, *args, **kwargs)
  326. nreq_param_list = ['column_names', 'bucket_boundaries', 'bucket_batch_sizes']
  327. type_check_list([column_names, bucket_boundaries, bucket_batch_sizes], (list,), nreq_param_list)
  328. nbool_param_list = ['pad_to_bucket_boundary', 'drop_remainder']
  329. type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,), nbool_param_list)
  330. # check column_names: must be list of string.
  331. check_columns(column_names, "column_names")
  332. if element_length_function is None and len(column_names) != 1:
  333. raise ValueError("If element_length_function is not specified, exactly one column name should be passed.")
  334. # check bucket_boundaries: must be list of int, positive and strictly increasing
  335. if not bucket_boundaries:
  336. raise ValueError("bucket_boundaries cannot be empty.")
  337. all_int = all(isinstance(item, int) for item in bucket_boundaries)
  338. if not all_int:
  339. raise TypeError("bucket_boundaries should be a list of int.")
  340. all_non_negative = all(item > 0 for item in bucket_boundaries)
  341. if not all_non_negative:
  342. raise ValueError("bucket_boundaries must only contain positive numbers.")
  343. for i in range(len(bucket_boundaries) - 1):
  344. if not bucket_boundaries[i + 1] > bucket_boundaries[i]:
  345. raise ValueError("bucket_boundaries should be strictly increasing.")
  346. # check bucket_batch_sizes: must be list of int and positive
  347. if len(bucket_batch_sizes) != len(bucket_boundaries) + 1:
  348. raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.")
  349. all_int = all(isinstance(item, int) for item in bucket_batch_sizes)
  350. if not all_int:
  351. raise TypeError("bucket_batch_sizes should be a list of int.")
  352. all_non_negative = all(item > 0 for item in bucket_batch_sizes)
  353. if not all_non_negative:
  354. raise ValueError("bucket_batch_sizes should be a list of positive numbers.")
  355. if pad_info is not None:
  356. type_check(pad_info, (dict,), "pad_info")
  357. for k, v in pad_info.items():
  358. check_pad_info(k, v)
  359. return method(self, *args, **kwargs)
  360. return new_method
  361. def check_batch(method):
  362. """check the input arguments of batch."""
  363. @wraps(method)
  364. def new_method(self, *args, **kwargs):
  365. [batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns, output_columns,
  366. column_order, pad_info], param_dict = parse_user_args(method, *args, **kwargs)
  367. if not (isinstance(batch_size, int) or (callable(batch_size))):
  368. raise TypeError("batch_size should either be an int or a callable.")
  369. if callable(batch_size):
  370. sig = ins.signature(batch_size)
  371. if len(sig.parameters) != 1:
  372. raise ValueError("batch_size callable should take one parameter (BatchInfo).")
  373. if num_parallel_workers is not None:
  374. check_num_parallel_workers(num_parallel_workers)
  375. type_check(drop_remainder, (bool,), "drop_remainder")
  376. if (pad_info is not None) and (per_batch_map is not None):
  377. raise ValueError("pad_info and per_batch_map can't both be set")
  378. if pad_info is not None:
  379. type_check(param_dict["pad_info"], (dict,), "pad_info")
  380. for k, v in param_dict.get('pad_info').items():
  381. check_pad_info(k, v)
  382. if (per_batch_map is None) != (input_columns is None):
  383. # These two parameters appear together.
  384. raise ValueError("per_batch_map and input_columns need to be passed in together.")
  385. if input_columns is not None:
  386. check_columns(input_columns, "input_columns")
  387. if len(input_columns) != (len(ins.signature(per_batch_map).parameters) - 1):
  388. raise ValueError("the signature of per_batch_map should match with input columns")
  389. if output_columns is not None:
  390. check_columns(output_columns, "output_columns")
  391. if column_order is not None:
  392. check_columns(column_order, "column_order")
  393. return method(self, *args, **kwargs)
  394. return new_method
  395. def check_sync_wait(method):
  396. """check the input arguments of sync_wait."""
  397. @wraps(method)
  398. def new_method(self, *args, **kwargs):
  399. [condition_name, num_batch, _], _ = parse_user_args(method, *args, **kwargs)
  400. type_check(condition_name, (str,), "condition_name")
  401. type_check(num_batch, (int,), "num_batch")
  402. return method(self, *args, **kwargs)
  403. return new_method
  404. def check_shuffle(method):
  405. """check the input arguments of shuffle."""
  406. @wraps(method)
  407. def new_method(self, *args, **kwargs):
  408. [buffer_size], _ = parse_user_args(method, *args, **kwargs)
  409. type_check(buffer_size, (int,), "buffer_size")
  410. check_value(buffer_size, [2, INT32_MAX], "buffer_size")
  411. return method(self, *args, **kwargs)
  412. return new_method
  413. def check_map(method):
  414. """check the input arguments of map."""
  415. @wraps(method)
  416. def new_method(self, *args, **kwargs):
  417. from mindspore.dataset.callback import DSCallback
  418. [_, input_columns, output_columns, column_order, num_parallel_workers, python_multiprocessing, cache,
  419. callbacks], _ = \
  420. parse_user_args(method, *args, **kwargs)
  421. nreq_param_columns = ['input_columns', 'output_columns', 'column_order']
  422. if column_order is not None:
  423. type_check(column_order, (list,), "column_order")
  424. if num_parallel_workers is not None:
  425. check_num_parallel_workers(num_parallel_workers)
  426. type_check(python_multiprocessing, (bool,), "python_multiprocessing")
  427. check_cache_option(cache)
  428. if callbacks is not None:
  429. if isinstance(callbacks, (list, tuple)):
  430. type_check_list(callbacks, (DSCallback,), "callbacks")
  431. else:
  432. type_check(callbacks, (DSCallback,), "callbacks")
  433. for param_name, param in zip(nreq_param_columns, [input_columns, output_columns, column_order]):
  434. if param is not None:
  435. check_columns(param, param_name)
  436. if callbacks is not None:
  437. type_check(callbacks, (list, DSCallback), "callbacks")
  438. return method(self, *args, **kwargs)
  439. return new_method
  440. def check_filter(method):
  441. """"check the input arguments of filter."""
  442. @wraps(method)
  443. def new_method(self, *args, **kwargs):
  444. [predicate, input_columns, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs)
  445. if not callable(predicate):
  446. raise TypeError("Predicate should be a Python function or a callable Python object.")
  447. check_num_parallel_workers(num_parallel_workers)
  448. if num_parallel_workers is not None:
  449. check_num_parallel_workers(num_parallel_workers)
  450. if input_columns is not None:
  451. check_columns(input_columns, "input_columns")
  452. return method(self, *args, **kwargs)
  453. return new_method
  454. def check_repeat(method):
  455. """check the input arguments of repeat."""
  456. @wraps(method)
  457. def new_method(self, *args, **kwargs):
  458. [count], _ = parse_user_args(method, *args, **kwargs)
  459. type_check(count, (int, type(None)), "repeat")
  460. if isinstance(count, int):
  461. if (count <= 0 and count != -1) or count > INT32_MAX:
  462. raise ValueError("count should be either -1 or positive integer.")
  463. return method(self, *args, **kwargs)
  464. return new_method
  465. def check_skip(method):
  466. """check the input arguments of skip."""
  467. @wraps(method)
  468. def new_method(self, *args, **kwargs):
  469. [count], _ = parse_user_args(method, *args, **kwargs)
  470. type_check(count, (int,), "count")
  471. check_value(count, (-1, INT32_MAX), "count")
  472. return method(self, *args, **kwargs)
  473. return new_method
  474. def check_take(method):
  475. """check the input arguments of take."""
  476. @wraps(method)
  477. def new_method(self, *args, **kwargs):
  478. [count], _ = parse_user_args(method, *args, **kwargs)
  479. type_check(count, (int,), "count")
  480. if (count <= 0 and count != -1) or count > INT32_MAX:
  481. raise ValueError("count should be either -1 or positive integer.")
  482. return method(self, *args, **kwargs)
  483. return new_method
  484. def check_positive_int32(method):
  485. """check whether the input argument is positive and int, only works for functions with one input."""
  486. @wraps(method)
  487. def new_method(self, *args, **kwargs):
  488. [count], param_dict = parse_user_args(method, *args, **kwargs)
  489. para_name = None
  490. for key in list(param_dict.keys()):
  491. if key not in ['self', 'cls']:
  492. para_name = key
  493. # Need to get default value of param
  494. if count is not None:
  495. check_pos_int32(count, para_name)
  496. return method(self, *args, **kwargs)
  497. return new_method
  498. def check_device_send(method):
  499. """check the input argument for to_device and device_que."""
  500. @wraps(method)
  501. def new_method(self, *args, **kwargs):
  502. param, param_dict = parse_user_args(method, *args, **kwargs)
  503. para_list = list(param_dict.keys())
  504. if "prefetch_size" in para_list:
  505. if param[0] is not None:
  506. check_pos_int32(param[0], "prefetch_size")
  507. type_check(param[1], (bool,), "send_epoch_end")
  508. else:
  509. type_check(param[0], (bool,), "send_epoch_end")
  510. return method(self, *args, **kwargs)
  511. return new_method
  512. def check_zip(method):
  513. """check the input arguments of zip."""
  514. @wraps(method)
  515. def new_method(*args, **kwargs):
  516. [ds], _ = parse_user_args(method, *args, **kwargs)
  517. type_check(ds, (tuple,), "datasets")
  518. return method(*args, **kwargs)
  519. return new_method
  520. def check_zip_dataset(method):
  521. """check the input arguments of zip method in `Dataset`."""
  522. @wraps(method)
  523. def new_method(self, *args, **kwargs):
  524. [ds], _ = parse_user_args(method, *args, **kwargs)
  525. type_check(ds, (tuple, datasets.Dataset), "datasets")
  526. return method(self, *args, **kwargs)
  527. return new_method
  528. def check_concat(method):
  529. """check the input arguments of concat method in `Dataset`."""
  530. @wraps(method)
  531. def new_method(self, *args, **kwargs):
  532. [ds], _ = parse_user_args(method, *args, **kwargs)
  533. type_check(ds, (list, datasets.Dataset), "datasets")
  534. if isinstance(ds, list):
  535. type_check_list(ds, (datasets.Dataset,), "dataset")
  536. return method(self, *args, **kwargs)
  537. return new_method
  538. def check_rename(method):
  539. """check the input arguments of rename."""
  540. @wraps(method)
  541. def new_method(self, *args, **kwargs):
  542. values, _ = parse_user_args(method, *args, **kwargs)
  543. req_param_columns = ['input_columns', 'output_columns']
  544. for param_name, param in zip(req_param_columns, values):
  545. check_columns(param, param_name)
  546. input_size, output_size = 1, 1
  547. input_columns, output_columns = values
  548. if isinstance(input_columns, list):
  549. input_size = len(input_columns)
  550. if isinstance(output_columns, list):
  551. output_size = len(output_columns)
  552. if input_size != output_size:
  553. raise ValueError("Number of column in input_columns and output_columns is not equal.")
  554. return method(self, *args, **kwargs)
  555. return new_method
  556. def check_project(method):
  557. """check the input arguments of project."""
  558. @wraps(method)
  559. def new_method(self, *args, **kwargs):
  560. [columns], _ = parse_user_args(method, *args, **kwargs)
  561. check_columns(columns, 'columns')
  562. return method(self, *args, **kwargs)
  563. return new_method
  564. def check_add_column(method):
  565. """check the input arguments of add_column."""
  566. @wraps(method)
  567. def new_method(self, *args, **kwargs):
  568. [name, de_type, shape], _ = parse_user_args(method, *args, **kwargs)
  569. type_check(name, (str,), "name")
  570. if not name:
  571. raise TypeError("Expected non-empty string.")
  572. if de_type is not None:
  573. if not isinstance(de_type, typing.Type) and not check_valid_detype(de_type):
  574. raise TypeError("Unknown column type.")
  575. else:
  576. raise TypeError("Expected non-empty string.")
  577. if shape is not None:
  578. type_check(shape, (list,), "shape")
  579. type_check_list(shape, (int,), "shape")
  580. return method(self, *args, **kwargs)
  581. return new_method
  582. def check_cluedataset(method):
  583. """A wrapper that wraps a parameter checker around the original Dataset(CLUEDataset)."""
  584. @wraps(method)
  585. def new_method(self, *args, **kwargs):
  586. _, param_dict = parse_user_args(method, *args, **kwargs)
  587. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  588. dataset_files = param_dict.get('dataset_files')
  589. type_check(dataset_files, (str, list), "dataset files")
  590. # check task
  591. task_param = param_dict.get('task')
  592. if task_param not in ['AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL']:
  593. raise ValueError("task should be AFQMC, TNEWS, IFLYTEK, CMNLI, WSC or CSL")
  594. # check usage
  595. usage_param = param_dict.get('usage')
  596. if usage_param not in ['train', 'test', 'eval']:
  597. raise ValueError("usage should be train, test or eval")
  598. validate_dataset_param_value(nreq_param_int, param_dict, int)
  599. check_sampler_shuffle_shard_options(param_dict)
  600. cache = param_dict.get('cache')
  601. check_cache_option(cache)
  602. return method(self, *args, **kwargs)
  603. return new_method
  604. def check_csvdataset(method):
  605. """A wrapper that wraps a parameter checker around the original Dataset(CSVDataset)."""
  606. @wraps(method)
  607. def new_method(self, *args, **kwargs):
  608. _, param_dict = parse_user_args(method, *args, **kwargs)
  609. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  610. # check dataset_files; required argument
  611. dataset_files = param_dict.get('dataset_files')
  612. type_check(dataset_files, (str, list), "dataset files")
  613. # check field_delim
  614. field_delim = param_dict.get('field_delim')
  615. type_check(field_delim, (str,), 'field delim')
  616. if field_delim in ['"', '\r', '\n'] or len(field_delim) > 1:
  617. raise ValueError("field_delim is not legal.")
  618. # check column_defaults
  619. column_defaults = param_dict.get('column_defaults')
  620. if column_defaults is not None:
  621. if not isinstance(column_defaults, list):
  622. raise TypeError("column_defaults should be type of list.")
  623. for item in column_defaults:
  624. if not isinstance(item, (str, int, float)):
  625. raise TypeError("column type is not legal.")
  626. # check column_names: must be list of string.
  627. column_names = param_dict.get("column_names")
  628. if column_names is not None:
  629. all_string = all(isinstance(item, str) for item in column_names)
  630. if not all_string:
  631. raise TypeError("column_names should be a list of str.")
  632. validate_dataset_param_value(nreq_param_int, param_dict, int)
  633. check_sampler_shuffle_shard_options(param_dict)
  634. cache = param_dict.get('cache')
  635. check_cache_option(cache)
  636. return method(self, *args, **kwargs)
  637. return new_method
  638. def check_textfiledataset(method):
  639. """A wrapper that wraps a parameter checker around the original Dataset(TextFileDataset)."""
  640. @wraps(method)
  641. def new_method(self, *args, **kwargs):
  642. _, param_dict = parse_user_args(method, *args, **kwargs)
  643. nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
  644. dataset_files = param_dict.get('dataset_files')
  645. type_check(dataset_files, (str, list), "dataset files")
  646. validate_dataset_param_value(nreq_param_int, param_dict, int)
  647. check_sampler_shuffle_shard_options(param_dict)
  648. cache = param_dict.get('cache')
  649. check_cache_option(cache)
  650. return method(self, *args, **kwargs)
  651. return new_method
  652. def check_split(method):
  653. """check the input arguments of split."""
  654. @wraps(method)
  655. def new_method(self, *args, **kwargs):
  656. [sizes, randomize], _ = parse_user_args(method, *args, **kwargs)
  657. type_check(sizes, (list,), "sizes")
  658. type_check(randomize, (bool,), "randomize")
  659. # check sizes: must be list of float or list of int
  660. if not sizes:
  661. raise ValueError("sizes cannot be empty.")
  662. all_int = all(isinstance(item, int) for item in sizes)
  663. all_float = all(isinstance(item, float) for item in sizes)
  664. if not (all_int or all_float):
  665. raise ValueError("sizes should be list of int or list of float.")
  666. if all_int:
  667. all_positive = all(item > 0 for item in sizes)
  668. if not all_positive:
  669. raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.")
  670. if all_float:
  671. all_valid_percentages = all(0 < item <= 1 for item in sizes)
  672. if not all_valid_percentages:
  673. raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].")
  674. epsilon = 0.00001
  675. if not abs(sum(sizes) - 1) < epsilon:
  676. raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")
  677. return method(self, *args, **kwargs)
  678. return new_method
  679. def check_hostname(hostname):
  680. if not hostname or len(hostname) > 255:
  681. return False
  682. if hostname[-1] == ".":
  683. hostname = hostname[:-1] # strip exactly one dot from the right, if present
  684. allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE)
  685. return all(allowed.match(x) for x in hostname.split("."))
  686. def check_gnn_graphdata(method):
  687. """check the input arguments of graphdata."""
  688. @wraps(method)
  689. def new_method(self, *args, **kwargs):
  690. [dataset_file, num_parallel_workers, working_mode, hostname,
  691. port, num_client, auto_shutdown], _ = parse_user_args(method, *args, **kwargs)
  692. check_file(dataset_file)
  693. if num_parallel_workers is not None:
  694. check_num_parallel_workers(num_parallel_workers)
  695. type_check(hostname, (str,), "hostname")
  696. if check_hostname(hostname) is False:
  697. raise ValueError("The hostname is illegal")
  698. type_check(working_mode, (str,), "working_mode")
  699. if working_mode not in {'local', 'client', 'server'}:
  700. raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'")
  701. type_check(port, (int,), "port")
  702. check_value(port, (1024, 65535), "port")
  703. type_check(num_client, (int,), "num_client")
  704. check_value(num_client, (1, 255), "num_client")
  705. type_check(auto_shutdown, (bool,), "auto_shutdown")
  706. return method(self, *args, **kwargs)
  707. return new_method
  708. def check_gnn_get_all_nodes(method):
  709. """A wrapper that wraps a parameter checker around the GNN `get_all_nodes` function."""
  710. @wraps(method)
  711. def new_method(self, *args, **kwargs):
  712. [node_type], _ = parse_user_args(method, *args, **kwargs)
  713. type_check(node_type, (int,), "node_type")
  714. return method(self, *args, **kwargs)
  715. return new_method
  716. def check_gnn_get_all_edges(method):
  717. """A wrapper that wraps a parameter checker around the GNN `get_all_edges` function."""
  718. @wraps(method)
  719. def new_method(self, *args, **kwargs):
  720. [edge_type], _ = parse_user_args(method, *args, **kwargs)
  721. type_check(edge_type, (int,), "edge_type")
  722. return method(self, *args, **kwargs)
  723. return new_method
  724. def check_gnn_get_nodes_from_edges(method):
  725. """A wrapper that wraps a parameter checker around the GNN `get_nodes_from_edges` function."""
  726. @wraps(method)
  727. def new_method(self, *args, **kwargs):
  728. [edge_list], _ = parse_user_args(method, *args, **kwargs)
  729. check_gnn_list_or_ndarray(edge_list, "edge_list")
  730. return method(self, *args, **kwargs)
  731. return new_method
  732. def check_gnn_get_all_neighbors(method):
  733. """A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function."""
  734. @wraps(method)
  735. def new_method(self, *args, **kwargs):
  736. [node_list, neighbour_type], _ = parse_user_args(method, *args, **kwargs)
  737. check_gnn_list_or_ndarray(node_list, 'node_list')
  738. type_check(neighbour_type, (int,), "neighbour_type")
  739. return method(self, *args, **kwargs)
  740. return new_method
  741. def check_gnn_get_sampled_neighbors(method):
  742. """A wrapper that wraps a parameter checker around the GNN `get_sampled_neighbors` function."""
  743. @wraps(method)
  744. def new_method(self, *args, **kwargs):
  745. [node_list, neighbor_nums, neighbor_types], _ = parse_user_args(method, *args, **kwargs)
  746. check_gnn_list_or_ndarray(node_list, 'node_list')
  747. check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
  748. if not neighbor_nums or len(neighbor_nums) > 6:
  749. raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
  750. 'neighbor_nums', len(neighbor_nums)))
  751. check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
  752. if not neighbor_types or len(neighbor_types) > 6:
  753. raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
  754. 'neighbor_types', len(neighbor_types)))
  755. if len(neighbor_nums) != len(neighbor_types):
  756. raise ValueError(
  757. "The number of members of neighbor_nums and neighbor_types is inconsistent")
  758. return method(self, *args, **kwargs)
  759. return new_method
  760. def check_gnn_get_neg_sampled_neighbors(method):
  761. """A wrapper that wraps a parameter checker around the GNN `get_neg_sampled_neighbors` function."""
  762. @wraps(method)
  763. def new_method(self, *args, **kwargs):
  764. [node_list, neg_neighbor_num, neg_neighbor_type], _ = parse_user_args(method, *args, **kwargs)
  765. check_gnn_list_or_ndarray(node_list, 'node_list')
  766. type_check(neg_neighbor_num, (int,), "neg_neighbor_num")
  767. type_check(neg_neighbor_type, (int,), "neg_neighbor_type")
  768. return method(self, *args, **kwargs)
  769. return new_method
  770. def check_gnn_random_walk(method):
  771. """A wrapper that wraps a parameter checker around the GNN `random_walk` function."""
  772. @wraps(method)
  773. def new_method(self, *args, **kwargs):
  774. [target_nodes, meta_path, step_home_param, step_away_param, default_node], _ = parse_user_args(method, *args,
  775. **kwargs)
  776. check_gnn_list_or_ndarray(target_nodes, 'target_nodes')
  777. check_gnn_list_or_ndarray(meta_path, 'meta_path')
  778. type_check(step_home_param, (float,), "step_home_param")
  779. type_check(step_away_param, (float,), "step_away_param")
  780. type_check(default_node, (int,), "default_node")
  781. check_value(default_node, (-1, INT32_MAX), "default_node")
  782. return method(self, *args, **kwargs)
  783. return new_method
  784. def check_aligned_list(param, param_name, member_type):
  785. """Check whether the structure of each member of the list is the same."""
  786. type_check(param, (list,), "param")
  787. if not param:
  788. raise TypeError(
  789. "Parameter {0} or its members are empty".format(param_name))
  790. member_have_list = None
  791. list_len = None
  792. for member in param:
  793. if isinstance(member, list):
  794. check_aligned_list(member, param_name, member_type)
  795. if member_have_list not in (None, True):
  796. raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
  797. param_name))
  798. if list_len is not None and len(member) != list_len:
  799. raise TypeError("The size of each member of parameter {0} is inconsistent".format(
  800. param_name))
  801. member_have_list = True
  802. list_len = len(member)
  803. else:
  804. type_check(member, (member_type,), param_name)
  805. if member_have_list not in (None, False):
  806. raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
  807. param_name))
  808. member_have_list = False
  809. def check_gnn_get_node_feature(method):
  810. """A wrapper that wraps a parameter checker around the GNN `get_node_feature` function."""
  811. @wraps(method)
  812. def new_method(self, *args, **kwargs):
  813. [node_list, feature_types], _ = parse_user_args(method, *args, **kwargs)
  814. type_check(node_list, (list, np.ndarray), "node_list")
  815. if isinstance(node_list, list):
  816. check_aligned_list(node_list, 'node_list', int)
  817. elif isinstance(node_list, np.ndarray):
  818. if not node_list.dtype == np.int32:
  819. raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
  820. node_list, node_list.dtype))
  821. check_gnn_list_or_ndarray(feature_types, 'feature_types')
  822. return method(self, *args, **kwargs)
  823. return new_method
  824. def check_gnn_get_edge_feature(method):
  825. """A wrapper that wraps a parameter checker around the GNN `get_edge_feature` function."""
  826. @wraps(method)
  827. def new_method(self, *args, **kwargs):
  828. [edge_list, feature_types], _ = parse_user_args(method, *args, **kwargs)
  829. type_check(edge_list, (list, np.ndarray), "edge_list")
  830. if isinstance(edge_list, list):
  831. check_aligned_list(edge_list, 'edge_list', int)
  832. elif isinstance(edge_list, np.ndarray):
  833. if not edge_list.dtype == np.int32:
  834. raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
  835. edge_list, edge_list.dtype))
  836. check_gnn_list_or_ndarray(feature_types, 'feature_types')
  837. return method(self, *args, **kwargs)
  838. return new_method
  839. def check_numpyslicesdataset(method):
  840. """A wrapper that wraps a parameter checker around the original Dataset(NumpySlicesDataset)."""
  841. @wraps(method)
  842. def new_method(self, *args, **kwargs):
  843. _, param_dict = parse_user_args(method, *args, **kwargs)
  844. data = param_dict.get("data")
  845. column_names = param_dict.get("column_names")
  846. if not data:
  847. raise ValueError("Argument data cannot be empty")
  848. type_check(data, (list, tuple, dict, np.ndarray), "data")
  849. if isinstance(data, tuple):
  850. type_check(data[0], (list, np.ndarray), "data[0]")
  851. # check column_names
  852. if column_names is not None:
  853. check_columns(column_names, "column_names")
  854. # check num of input column in column_names
  855. column_num = 1 if isinstance(column_names, str) else len(column_names)
  856. if isinstance(data, dict):
  857. data_column = len(list(data.keys()))
  858. if column_num != data_column:
  859. raise ValueError("Num of input column names is {0}, but required is {1}."
  860. .format(column_num, data_column))
  861. elif isinstance(data, tuple):
  862. if column_num != len(data):
  863. raise ValueError("Num of input column names is {0}, but required is {1}."
  864. .format(column_num, len(data)))
  865. else:
  866. if column_num != 1:
  867. raise ValueError("Num of input column names is {0}, but required is {1} as data is list."
  868. .format(column_num, 1))
  869. return method(self, *args, **kwargs)
  870. return new_method
  871. def check_paddeddataset(method):
  872. """A wrapper that wraps a parameter checker around the original Dataset(PaddedDataset)."""
  873. @wraps(method)
  874. def new_method(self, *args, **kwargs):
  875. _, param_dict = parse_user_args(method, *args, **kwargs)
  876. padded_samples = param_dict.get("padded_samples")
  877. if not padded_samples:
  878. raise ValueError("Argument padded_samples cannot be empty")
  879. type_check(padded_samples, (list,), "padded_samples")
  880. type_check(padded_samples[0], (dict,), "padded_element")
  881. return method(self, *args, **kwargs)
  882. return new_method
  883. def check_cache_option(cache):
  884. """Sanity check for cache parameter"""
  885. if cache is not None:
  886. if os.getenv('MS_ENABLE_CACHE') != 'TRUE':
  887. # temporary disable cache feature in the current release
  888. raise ValueError("Caching is disabled in the current release")
  889. from . import cache_client
  890. type_check(cache, (cache_client.DatasetCache,), "cache")