You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validators.py 26 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Validators for TensorOps.
  16. """
  17. import numbers
  18. from functools import wraps
  19. import numpy as np
  20. from mindspore._c_dataengine import TensorOp
  21. from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
  22. check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
  23. check_tensor_op, UINT8_MAX, check_value_normalize_std
  24. from .utils import Inter, Border, ImageBatchFormat
  25. def check_crop_size(size):
  26. """Wrapper method to check the parameters of crop size."""
  27. type_check(size, (int, list, tuple), "size")
  28. if isinstance(size, int):
  29. check_value(size, (1, FLOAT_MAX_INTEGER))
  30. elif isinstance(size, (tuple, list)) and len(size) == 2:
  31. for value in size:
  32. check_value(value, (1, FLOAT_MAX_INTEGER))
  33. else:
  34. raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
  35. def check_cut_mix_batch_c(method):
  36. """Wrapper method to check the parameters of CutMixBatch."""
  37. @wraps(method)
  38. def new_method(self, *args, **kwargs):
  39. [image_batch_format, alpha, prob], _ = parse_user_args(method, *args, **kwargs)
  40. type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format")
  41. check_pos_float32(alpha)
  42. check_positive(alpha, "alpha")
  43. check_value(prob, [0, 1], "prob")
  44. return method(self, *args, **kwargs)
  45. return new_method
  46. def check_resize_size(size):
  47. """Wrapper method to check the parameters of resize."""
  48. if isinstance(size, int):
  49. check_value(size, (1, FLOAT_MAX_INTEGER))
  50. elif isinstance(size, (tuple, list)) and len(size) == 2:
  51. for i, value in enumerate(size):
  52. check_value(value, (1, INT32_MAX), "size at dim {0}".format(i))
  53. else:
  54. raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
  55. def check_mix_up_batch_c(method):
  56. """Wrapper method to check the parameters of MixUpBatch."""
  57. @wraps(method)
  58. def new_method(self, *args, **kwargs):
  59. [alpha], _ = parse_user_args(method, *args, **kwargs)
  60. check_positive(alpha, "alpha")
  61. check_pos_float32(alpha)
  62. return method(self, *args, **kwargs)
  63. return new_method
  64. def check_normalize_c_param(mean, std):
  65. if len(mean) != len(std):
  66. raise ValueError("Length of mean and std must be equal")
  67. for mean_value in mean:
  68. check_pos_float32(mean_value)
  69. for std_value in std:
  70. check_pos_float32(std_value)
  71. def check_normalize_py_param(mean, std):
  72. if len(mean) != len(std):
  73. raise ValueError("Length of mean and std must be equal")
  74. for mean_value in mean:
  75. check_value(mean_value, [0., 1.], "mean_value")
  76. for std_value in std:
  77. check_value_normalize_std(std_value, [0., 1.], "std_value")
  78. def check_fill_value(fill_value):
  79. if isinstance(fill_value, int):
  80. check_uint8(fill_value)
  81. elif isinstance(fill_value, tuple) and len(fill_value) == 3:
  82. for value in fill_value:
  83. check_uint8(value)
  84. else:
  85. raise TypeError("fill_value should be a single integer or a 3-tuple.")
  86. def check_padding(padding):
  87. """Parsing the padding arguments and check if it is legal."""
  88. type_check(padding, (tuple, list, numbers.Number), "padding")
  89. if isinstance(padding, numbers.Number):
  90. check_value(padding, (0, INT32_MAX), "padding")
  91. if isinstance(padding, (tuple, list)):
  92. if len(padding) not in (2, 4):
  93. raise ValueError("The size of the padding list or tuple should be 2 or 4.")
  94. for i, pad_value in enumerate(padding):
  95. type_check(pad_value, (int,), "padding[{}]".format(i))
  96. check_value(pad_value, (0, INT32_MAX), "pad_value")
  97. def check_degrees(degrees):
  98. """Check if the degrees is legal."""
  99. type_check(degrees, (numbers.Number, list, tuple), "degrees")
  100. if isinstance(degrees, numbers.Number):
  101. check_value(degrees, (0, float("inf")), "degrees")
  102. elif isinstance(degrees, (list, tuple)):
  103. if len(degrees) == 2:
  104. type_check_list(degrees, (numbers.Number,), "degrees")
  105. if degrees[0] > degrees[1]:
  106. raise ValueError("degrees should be in (min,max) format. Got (max,min).")
  107. else:
  108. raise TypeError("If degrees is a sequence, the length must be 2.")
  109. def check_random_color_adjust_param(value, input_name, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True):
  110. """Check the parameters in random color adjust operation."""
  111. type_check(value, (numbers.Number, list, tuple), input_name)
  112. if isinstance(value, numbers.Number):
  113. if value < 0:
  114. raise ValueError("The input value of {} cannot be negative.".format(input_name))
  115. elif isinstance(value, (list, tuple)) and len(value) == 2:
  116. check_range(value, bound)
  117. if value[0] > value[1]:
  118. raise ValueError("value should be in (min,max) format. Got (max,min).")
  119. def check_erasing_value(value):
  120. if not (isinstance(value, (numbers.Number, str, bytes)) or
  121. (isinstance(value, (tuple, list)) and len(value) == 3)):
  122. raise ValueError("The value for erasing should be either a single value, "
  123. "or a string 'random', or a sequence of 3 elements for RGB respectively.")
  124. def check_crop(method):
  125. """A wrapper that wraps a parameter checker around the original function(crop operation)."""
  126. @wraps(method)
  127. def new_method(self, *args, **kwargs):
  128. [size], _ = parse_user_args(method, *args, **kwargs)
  129. check_crop_size(size)
  130. return method(self, *args, **kwargs)
  131. return new_method
  132. def check_posterize(method):
  133. """A wrapper that wraps a parameter checker around the original function(posterize operation)."""
  134. @wraps(method)
  135. def new_method(self, *args, **kwargs):
  136. [bits], _ = parse_user_args(method, *args, **kwargs)
  137. if bits is not None:
  138. type_check(bits, (list, tuple, int), "bits")
  139. if isinstance(bits, int):
  140. check_value(bits, [1, 8])
  141. if isinstance(bits, (list, tuple)):
  142. if len(bits) != 2:
  143. raise TypeError("Size of bits should be a single integer or a list/tuple (min, max) of length 2.")
  144. for item in bits:
  145. check_uint8(item, "bits")
  146. # also checks if min <= max
  147. check_range(bits, [1, 8])
  148. return method(self, *args, **kwargs)
  149. return new_method
  150. def check_resize_interpolation(method):
  151. """A wrapper that wraps a parameter checker around the original function(resize interpolation operation)."""
  152. @wraps(method)
  153. def new_method(self, *args, **kwargs):
  154. [size, interpolation], _ = parse_user_args(method, *args, **kwargs)
  155. check_resize_size(size)
  156. if interpolation is not None:
  157. type_check(interpolation, (Inter,), "interpolation")
  158. return method(self, *args, **kwargs)
  159. return new_method
  160. def check_resize(method):
  161. """A wrapper that wraps a parameter checker around the original function(resize operation)."""
  162. @wraps(method)
  163. def new_method(self, *args, **kwargs):
  164. [size], _ = parse_user_args(method, *args, **kwargs)
  165. check_resize_size(size)
  166. return method(self, *args, **kwargs)
  167. return new_method
  168. def check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts):
  169. """Wrapper method to check the parameters of RandomCropDecodeResize and SoftDvppDecodeRandomCropResizeJpeg."""
  170. check_crop_size(size)
  171. if scale is not None:
  172. type_check(scale, (tuple,), "scale")
  173. type_check_list(scale, (float, int), "scale")
  174. check_range(scale, [0, FLOAT_MAX_INTEGER])
  175. if scale[0] > scale[1]:
  176. raise ValueError("scale should be in (min,max) format. Got (max,min).")
  177. if ratio is not None:
  178. type_check(ratio, (tuple,), "ratio")
  179. type_check_list(ratio, (float, int), "ratio")
  180. check_range(ratio, [0, FLOAT_MAX_INTEGER])
  181. if ratio[0] > ratio[1]:
  182. raise ValueError("ratio should be in (min,max) format. Got (max,min).")
  183. if max_attempts is not None:
  184. check_value(max_attempts, (1, FLOAT_MAX_INTEGER))
  185. def check_random_resize_crop(method):
  186. """A wrapper that wraps a parameter checker around the original function(random resize crop operation)."""
  187. @wraps(method)
  188. def new_method(self, *args, **kwargs):
  189. [size, scale, ratio, interpolation, max_attempts], _ = parse_user_args(method, *args, **kwargs)
  190. if interpolation is not None:
  191. type_check(interpolation, (Inter,), "interpolation")
  192. check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
  193. return method(self, *args, **kwargs)
  194. return new_method
  195. def check_prob(method):
  196. """A wrapper that wraps a parameter checker (to confirm probability) around the original function."""
  197. @wraps(method)
  198. def new_method(self, *args, **kwargs):
  199. [prob], _ = parse_user_args(method, *args, **kwargs)
  200. type_check(prob, (float, int,), "prob")
  201. check_value(prob, [0., 1.], "prob")
  202. return method(self, *args, **kwargs)
  203. return new_method
  204. def check_normalize_c(method):
  205. """A wrapper that wraps a parameter checker around the original function(normalize operation written in C++)."""
  206. @wraps(method)
  207. def new_method(self, *args, **kwargs):
  208. [mean, std], _ = parse_user_args(method, *args, **kwargs)
  209. check_normalize_c_param(mean, std)
  210. return method(self, *args, **kwargs)
  211. return new_method
  212. def check_normalize_py(method):
  213. """A wrapper that wraps a parameter checker around the original function(normalize operation written in Python)."""
  214. @wraps(method)
  215. def new_method(self, *args, **kwargs):
  216. [mean, std], _ = parse_user_args(method, *args, **kwargs)
  217. check_normalize_py_param(mean, std)
  218. return method(self, *args, **kwargs)
  219. return new_method
  220. def check_random_crop(method):
  221. """Wrapper method to check the parameters of random crop."""
  222. @wraps(method)
  223. def new_method(self, *args, **kwargs):
  224. [size, padding, pad_if_needed, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)
  225. check_crop_size(size)
  226. type_check(pad_if_needed, (bool,), "pad_if_needed")
  227. if padding is not None:
  228. check_padding(padding)
  229. if fill_value is not None:
  230. check_fill_value(fill_value)
  231. if padding_mode is not None:
  232. type_check(padding_mode, (Border,), "padding_mode")
  233. return method(self, *args, **kwargs)
  234. return new_method
  235. def check_random_color_adjust(method):
  236. """Wrapper method to check the parameters of random color adjust."""
  237. @wraps(method)
  238. def new_method(self, *args, **kwargs):
  239. [brightness, contrast, saturation, hue], _ = parse_user_args(method, *args, **kwargs)
  240. check_random_color_adjust_param(brightness, "brightness")
  241. check_random_color_adjust_param(contrast, "contrast")
  242. check_random_color_adjust_param(saturation, "saturation")
  243. check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)
  244. return method(self, *args, **kwargs)
  245. return new_method
  246. def check_random_rotation(method):
  247. """Wrapper method to check the parameters of random rotation."""
  248. @wraps(method)
  249. def new_method(self, *args, **kwargs):
  250. [degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
  251. check_degrees(degrees)
  252. if resample is not None:
  253. type_check(resample, (Inter,), "resample")
  254. if expand is not None:
  255. type_check(expand, (bool,), "expand")
  256. if center is not None:
  257. check_2tuple(center, "center")
  258. if fill_value is not None:
  259. check_fill_value(fill_value)
  260. return method(self, *args, **kwargs)
  261. return new_method
  262. def check_ten_crop(method):
  263. """Wrapper method to check the parameters of crop."""
  264. @wraps(method)
  265. def new_method(self, *args, **kwargs):
  266. [size, use_vertical_flip], _ = parse_user_args(method, *args, **kwargs)
  267. check_crop_size(size)
  268. if use_vertical_flip is not None:
  269. type_check(use_vertical_flip, (bool,), "use_vertical_flip")
  270. return method(self, *args, **kwargs)
  271. return new_method
  272. def check_num_channels(method):
  273. """Wrapper method to check the parameters of number of channels."""
  274. @wraps(method)
  275. def new_method(self, *args, **kwargs):
  276. [num_output_channels], _ = parse_user_args(method, *args, **kwargs)
  277. if num_output_channels is not None:
  278. if num_output_channels not in (1, 3):
  279. raise ValueError("Number of channels of the output grayscale image"
  280. "should be either 1 or 3. Got {0}".format(num_output_channels))
  281. return method(self, *args, **kwargs)
  282. return new_method
  283. def check_pad(method):
  284. """Wrapper method to check the parameters of random pad."""
  285. @wraps(method)
  286. def new_method(self, *args, **kwargs):
  287. [padding, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)
  288. check_padding(padding)
  289. check_fill_value(fill_value)
  290. type_check(padding_mode, (Border,), "padding_mode")
  291. return method(self, *args, **kwargs)
  292. return new_method
  293. def check_random_perspective(method):
  294. """Wrapper method to check the parameters of random perspective."""
  295. @wraps(method)
  296. def new_method(self, *args, **kwargs):
  297. [distortion_scale, prob, interpolation], _ = parse_user_args(method, *args, **kwargs)
  298. check_value(distortion_scale, [0., 1.], "distortion_scale")
  299. check_value(prob, [0., 1.], "prob")
  300. type_check(interpolation, (Inter,), "interpolation")
  301. return method(self, *args, **kwargs)
  302. return new_method
  303. def check_mix_up(method):
  304. """Wrapper method to check the parameters of mix up."""
  305. @wraps(method)
  306. def new_method(self, *args, **kwargs):
  307. [batch_size, alpha, is_single], _ = parse_user_args(method, *args, **kwargs)
  308. check_value(batch_size, (1, FLOAT_MAX_INTEGER))
  309. check_positive(alpha, "alpha")
  310. type_check(is_single, (bool,), "is_single")
  311. return method(self, *args, **kwargs)
  312. return new_method
  313. def check_random_erasing(method):
  314. """Wrapper method to check the parameters of random erasing."""
  315. @wraps(method)
  316. def new_method(self, *args, **kwargs):
  317. [prob, scale, ratio, value, inplace, max_attempts], _ = parse_user_args(method, *args, **kwargs)
  318. check_value(prob, [0., 1.], "prob")
  319. check_range(scale, [0, FLOAT_MAX_INTEGER])
  320. check_range(ratio, [0, FLOAT_MAX_INTEGER])
  321. check_erasing_value(value)
  322. type_check(inplace, (bool,), "inplace")
  323. check_value(max_attempts, (1, FLOAT_MAX_INTEGER))
  324. return method(self, *args, **kwargs)
  325. return new_method
  326. def check_cutout(method):
  327. """Wrapper method to check the parameters of cutout operation."""
  328. @wraps(method)
  329. def new_method(self, *args, **kwargs):
  330. [length, num_patches], _ = parse_user_args(method, *args, **kwargs)
  331. check_value(length, (1, FLOAT_MAX_INTEGER))
  332. check_value(num_patches, (1, FLOAT_MAX_INTEGER))
  333. return method(self, *args, **kwargs)
  334. return new_method
  335. def check_linear_transform(method):
  336. """Wrapper method to check the parameters of linear transform."""
  337. @wraps(method)
  338. def new_method(self, *args, **kwargs):
  339. [transformation_matrix, mean_vector], _ = parse_user_args(method, *args, **kwargs)
  340. type_check(transformation_matrix, (np.ndarray,), "transformation_matrix")
  341. type_check(mean_vector, (np.ndarray,), "mean_vector")
  342. if transformation_matrix.shape[0] != transformation_matrix.shape[1]:
  343. raise ValueError("transformation_matrix should be a square matrix. "
  344. "Got shape {} instead".format(transformation_matrix.shape))
  345. if mean_vector.shape[0] != transformation_matrix.shape[0]:
  346. raise ValueError("mean_vector length {0} should match either one dimension of the square"
  347. "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
  348. return method(self, *args, **kwargs)
  349. return new_method
  350. def check_random_affine(method):
  351. """Wrapper method to check the parameters of random affine."""
  352. @wraps(method)
  353. def new_method(self, *args, **kwargs):
  354. [degrees, translate, scale, shear, resample, fill_value], _ = parse_user_args(method, *args, **kwargs)
  355. check_degrees(degrees)
  356. if translate is not None:
  357. type_check(translate, (list, tuple), "translate")
  358. type_check_list(translate, (int, float), "translate")
  359. if len(translate) != 2 and len(translate) != 4:
  360. raise TypeError("translate should be a list or tuple of length 2 or 4.")
  361. for i, t in enumerate(translate):
  362. check_value(t, [-1.0, 1.0], "translate at {0}".format(i))
  363. if scale is not None:
  364. type_check(scale, (tuple, list), "scale")
  365. type_check_list(scale, (int, float), "scale")
  366. if len(scale) == 2:
  367. for i, s in enumerate(scale):
  368. check_positive(s, "scale[{}]".format(i))
  369. if scale[0] > scale[1]:
  370. raise ValueError("Input scale[1] must be equal to or greater than scale[0].")
  371. else:
  372. raise TypeError("scale should be a list or tuple of length 2.")
  373. if shear is not None:
  374. type_check(shear, (numbers.Number, tuple, list), "shear")
  375. if isinstance(shear, numbers.Number):
  376. check_positive(shear, "shear")
  377. else:
  378. type_check_list(shear, (int, float), "shear")
  379. if len(shear) not in (2, 4):
  380. raise TypeError("shear must be of length 2 or 4.")
  381. if len(shear) == 2 and shear[0] > shear[1]:
  382. raise ValueError("Input shear[1] must be equal to or greater than shear[0]")
  383. if len(shear) == 4 and (shear[0] > shear[1] or shear[2] > shear[3]):
  384. raise ValueError("Input shear[1] must be equal to or greater than shear[0] and "
  385. "shear[3] must be equal to or greater than shear[2].")
  386. type_check(resample, (Inter,), "resample")
  387. if fill_value is not None:
  388. check_fill_value(fill_value)
  389. return method(self, *args, **kwargs)
  390. return new_method
  391. def check_rescale(method):
  392. """Wrapper method to check the parameters of rescale."""
  393. @wraps(method)
  394. def new_method(self, *args, **kwargs):
  395. [rescale, shift], _ = parse_user_args(method, *args, **kwargs)
  396. check_pos_float32(rescale)
  397. type_check(shift, (numbers.Number,), "shift")
  398. return method(self, *args, **kwargs)
  399. return new_method
  400. def check_uniform_augment_cpp(method):
  401. """Wrapper method to check the parameters of UniformAugment C++ op."""
  402. @wraps(method)
  403. def new_method(self, *args, **kwargs):
  404. [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
  405. type_check(num_ops, (int,), "num_ops")
  406. check_positive(num_ops, "num_ops")
  407. if num_ops > len(transforms):
  408. raise ValueError("num_ops is greater than transforms list size")
  409. type_check_list(transforms, (TensorOp,), "tensor_ops")
  410. return method(self, *args, **kwargs)
  411. return new_method
  412. def check_bounding_box_augment_cpp(method):
  413. """Wrapper method to check the parameters of BoundingBoxAugment C++ op."""
  414. @wraps(method)
  415. def new_method(self, *args, **kwargs):
  416. [transform, ratio], _ = parse_user_args(method, *args, **kwargs)
  417. type_check(ratio, (float, int), "ratio")
  418. check_value(ratio, [0., 1.], "ratio")
  419. type_check(transform, (TensorOp,), "transform")
  420. return method(self, *args, **kwargs)
  421. return new_method
  422. def check_auto_contrast(method):
  423. """Wrapper method to check the parameters of AutoContrast ops (Python and C++)."""
  424. @wraps(method)
  425. def new_method(self, *args, **kwargs):
  426. [cutoff, ignore], _ = parse_user_args(method, *args, **kwargs)
  427. type_check(cutoff, (int, float), "cutoff")
  428. check_value(cutoff, [0, 100], "cutoff")
  429. if ignore is not None:
  430. type_check(ignore, (list, tuple, int), "ignore")
  431. if isinstance(ignore, int):
  432. check_value(ignore, [0, 255], "ignore")
  433. if isinstance(ignore, (list, tuple)):
  434. for item in ignore:
  435. type_check(item, (int,), "item")
  436. check_value(item, [0, 255], "ignore")
  437. return method(self, *args, **kwargs)
  438. return new_method
  439. def check_uniform_augment_py(method):
  440. """Wrapper method to check the parameters of Python UniformAugment op."""
  441. @wraps(method)
  442. def new_method(self, *args, **kwargs):
  443. [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
  444. type_check(transforms, (list,), "transforms")
  445. if not transforms:
  446. raise ValueError("transforms list is empty.")
  447. for transform in transforms:
  448. if isinstance(transform, TensorOp):
  449. raise ValueError("transform list only accepts Python operations.")
  450. type_check(num_ops, (int,), "num_ops")
  451. check_positive(num_ops, "num_ops")
  452. if num_ops > len(transforms):
  453. raise ValueError("num_ops cannot be greater than the length of transforms list.")
  454. return method(self, *args, **kwargs)
  455. return new_method
  456. def check_positive_degrees(method):
  457. """A wrapper method to check degrees parameter in RandomSharpness and RandomColor ops (Python and C++)"""
  458. @wraps(method)
  459. def new_method(self, *args, **kwargs):
  460. [degrees], _ = parse_user_args(method, *args, **kwargs)
  461. if degrees is not None:
  462. if not isinstance(degrees, (list, tuple)):
  463. raise TypeError("degrees must be either a tuple or a list.")
  464. type_check_list(degrees, (int, float), "degrees")
  465. if len(degrees) != 2:
  466. raise ValueError("degrees must be a sequence with length 2.")
  467. for degree in degrees:
  468. check_value(degree, (0, FLOAT_MAX_INTEGER))
  469. if degrees[0] > degrees[1]:
  470. raise ValueError("degrees should be in (min,max) format. Got (max,min).")
  471. return method(self, *args, **kwargs)
  472. return new_method
  473. def check_random_select_subpolicy_op(method):
  474. """Wrapper method to check the parameters of RandomSelectSubpolicyOp."""
  475. @wraps(method)
  476. def new_method(self, *args, **kwargs):
  477. [policy], _ = parse_user_args(method, *args, **kwargs)
  478. type_check(policy, (list,), "policy")
  479. if not policy:
  480. raise ValueError("policy can not be empty.")
  481. for sub_ind, sub in enumerate(policy):
  482. type_check(sub, (list,), "policy[{0}]".format([sub_ind]))
  483. if not sub:
  484. raise ValueError("policy[{0}] can not be empty.".format(sub_ind))
  485. for op_ind, tp in enumerate(sub):
  486. check_2tuple(tp, "policy[{0}][{1}]".format(sub_ind, op_ind))
  487. check_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind))
  488. check_value(tp[1], (0, 1), "prob of (op, prob) policy[{0}][{1}]".format(sub_ind, op_ind))
  489. return method(self, *args, **kwargs)
  490. return new_method
  491. def check_soft_dvpp_decode_random_crop_resize_jpeg(method):
  492. """Wrapper method to check the parameters of SoftDvppDecodeRandomCropResizeJpeg."""
  493. @wraps(method)
  494. def new_method(self, *args, **kwargs):
  495. [size, scale, ratio, max_attempts], _ = parse_user_args(method, *args, **kwargs)
  496. check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
  497. return method(self, *args, **kwargs)
  498. return new_method
  499. def check_random_solarize(method):
  500. """Wrapper method to check the parameters of RandomSolarizeOp."""
  501. @wraps(method)
  502. def new_method(self, *args, **kwargs):
  503. [threshold], _ = parse_user_args(method, *args, **kwargs)
  504. type_check(threshold, (tuple,), "threshold")
  505. type_check_list(threshold, (int,), "threshold")
  506. if len(threshold) != 2:
  507. raise ValueError("threshold must be a sequence of two numbers")
  508. for element in threshold:
  509. check_value(element, (0, UINT8_MAX))
  510. if threshold[1] < threshold[0]:
  511. raise ValueError("threshold must be in min max format numbers")
  512. return method(self, *args, **kwargs)
  513. return new_method