You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validators.py 29 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Validators for TensorOps.
  16. """
  17. import numbers
  18. from functools import wraps
  19. import numpy as np
  20. from mindspore._c_dataengine import TensorOp, TensorOperation
  21. from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
  22. check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
  23. check_c_tensor_op, UINT8_MAX, check_value_normalize_std, check_value_cutoff
  24. from .utils import Inter, Border, ImageBatchFormat
  25. def check_crop_size(size):
  26. """Wrapper method to check the parameters of crop size."""
  27. type_check(size, (int, list, tuple), "size")
  28. if isinstance(size, int):
  29. check_value(size, (1, FLOAT_MAX_INTEGER))
  30. elif isinstance(size, (tuple, list)) and len(size) == 2:
  31. for value in size:
  32. check_value(value, (1, FLOAT_MAX_INTEGER))
  33. else:
  34. raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
  35. def check_cut_mix_batch_c(method):
  36. """Wrapper method to check the parameters of CutMixBatch."""
  37. @wraps(method)
  38. def new_method(self, *args, **kwargs):
  39. [image_batch_format, alpha, prob], _ = parse_user_args(method, *args, **kwargs)
  40. type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format")
  41. check_pos_float32(alpha)
  42. check_positive(alpha, "alpha")
  43. check_value(prob, [0, 1], "prob")
  44. return method(self, *args, **kwargs)
  45. return new_method
  46. def check_resize_size(size):
  47. """Wrapper method to check the parameters of resize."""
  48. if isinstance(size, int):
  49. check_value(size, (1, FLOAT_MAX_INTEGER))
  50. elif isinstance(size, (tuple, list)) and len(size) == 2:
  51. for i, value in enumerate(size):
  52. type_check(value, (int,), "size at dim {0}".format(i))
  53. check_value(value, (1, INT32_MAX), "size at dim {0}".format(i))
  54. else:
  55. raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
  56. def check_mix_up_batch_c(method):
  57. """Wrapper method to check the parameters of MixUpBatch."""
  58. @wraps(method)
  59. def new_method(self, *args, **kwargs):
  60. [alpha], _ = parse_user_args(method, *args, **kwargs)
  61. check_positive(alpha, "alpha")
  62. check_pos_float32(alpha)
  63. return method(self, *args, **kwargs)
  64. return new_method
  65. def check_normalize_c_param(mean, std):
  66. type_check(mean, (list, tuple), "mean")
  67. type_check(std, (list, tuple), "std")
  68. if len(mean) != len(std):
  69. raise ValueError("Length of mean and std must be equal.")
  70. for mean_value in mean:
  71. check_value(mean_value, [0, 255], "mean_value")
  72. for std_value in std:
  73. check_value_normalize_std(std_value, [0, 255], "std_value")
  74. def check_normalize_py_param(mean, std):
  75. if len(mean) != len(std):
  76. raise ValueError("Length of mean and std must be equal.")
  77. for mean_value in mean:
  78. check_value(mean_value, [0., 1.], "mean_value")
  79. for std_value in std:
  80. check_value_normalize_std(std_value, [0., 1.], "std_value")
  81. def check_fill_value(fill_value):
  82. if isinstance(fill_value, int):
  83. check_uint8(fill_value)
  84. elif isinstance(fill_value, tuple) and len(fill_value) == 3:
  85. for value in fill_value:
  86. check_uint8(value)
  87. else:
  88. raise TypeError("fill_value should be a single integer or a 3-tuple.")
  89. def check_padding(padding):
  90. """Parsing the padding arguments and check if it is legal."""
  91. type_check(padding, (tuple, list, numbers.Number), "padding")
  92. if isinstance(padding, numbers.Number):
  93. check_value(padding, (0, INT32_MAX), "padding")
  94. if isinstance(padding, (tuple, list)):
  95. if len(padding) not in (2, 4):
  96. raise ValueError("The size of the padding list or tuple should be 2 or 4.")
  97. for i, pad_value in enumerate(padding):
  98. type_check(pad_value, (int,), "padding[{}]".format(i))
  99. check_value(pad_value, (0, INT32_MAX), "pad_value")
  100. def check_degrees(degrees):
  101. """Check if the degrees is legal."""
  102. type_check(degrees, (numbers.Number, list, tuple), "degrees")
  103. if isinstance(degrees, numbers.Number):
  104. check_value(degrees, (0, float("inf")), "degrees")
  105. elif isinstance(degrees, (list, tuple)):
  106. if len(degrees) == 2:
  107. type_check_list(degrees, (numbers.Number,), "degrees")
  108. if degrees[0] > degrees[1]:
  109. raise ValueError("degrees should be in (min,max) format. Got (max,min).")
  110. else:
  111. raise TypeError("If degrees is a sequence, the length must be 2.")
  112. def check_random_color_adjust_param(value, input_name, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True):
  113. """Check the parameters in random color adjust operation."""
  114. type_check(value, (numbers.Number, list, tuple), input_name)
  115. if isinstance(value, numbers.Number):
  116. if value < 0:
  117. raise ValueError("The input value of {} cannot be negative.".format(input_name))
  118. elif isinstance(value, (list, tuple)):
  119. if len(value) != 2:
  120. raise TypeError("If {0} is a sequence, the length must be 2.".format(input_name))
  121. if value[0] > value[1]:
  122. raise ValueError("{0} value should be in (min,max) format. Got ({1}, {2}).".format(input_name,
  123. value[0], value[1]))
  124. check_range(value, bound)
  125. def check_erasing_value(value):
  126. if not (isinstance(value, (numbers.Number, str, bytes)) or
  127. (isinstance(value, (tuple, list)) and len(value) == 3)):
  128. raise ValueError("The value for erasing should be either a single value, "
  129. "or a string 'random', or a sequence of 3 elements for RGB respectively.")
  130. def check_crop(method):
  131. """A wrapper that wraps a parameter checker around the original function(crop operation)."""
  132. @wraps(method)
  133. def new_method(self, *args, **kwargs):
  134. [size], _ = parse_user_args(method, *args, **kwargs)
  135. check_crop_size(size)
  136. return method(self, *args, **kwargs)
  137. return new_method
  138. def check_posterize(method):
  139. """A wrapper that wraps a parameter checker around the original function(posterize operation)."""
  140. @wraps(method)
  141. def new_method(self, *args, **kwargs):
  142. [bits], _ = parse_user_args(method, *args, **kwargs)
  143. if bits is not None:
  144. type_check(bits, (list, tuple, int), "bits")
  145. if isinstance(bits, int):
  146. check_value(bits, [1, 8])
  147. if isinstance(bits, (list, tuple)):
  148. if len(bits) != 2:
  149. raise TypeError("Size of bits should be a single integer or a list/tuple (min, max) of length 2.")
  150. for item in bits:
  151. check_uint8(item, "bits")
  152. # also checks if min <= max
  153. check_range(bits, [1, 8])
  154. return method(self, *args, **kwargs)
  155. return new_method
  156. def check_resize_interpolation(method):
  157. """A wrapper that wraps a parameter checker around the original function(resize interpolation operation)."""
  158. @wraps(method)
  159. def new_method(self, *args, **kwargs):
  160. [size, interpolation], _ = parse_user_args(method, *args, **kwargs)
  161. if interpolation is None:
  162. raise KeyError("Interpolation should not be None")
  163. check_resize_size(size)
  164. type_check(interpolation, (Inter,), "interpolation")
  165. return method(self, *args, **kwargs)
  166. return new_method
  167. def check_resize(method):
  168. """A wrapper that wraps a parameter checker around the original function(resize operation)."""
  169. @wraps(method)
  170. def new_method(self, *args, **kwargs):
  171. [size], _ = parse_user_args(method, *args, **kwargs)
  172. check_resize_size(size)
  173. return method(self, *args, **kwargs)
  174. return new_method
  175. def check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts):
  176. """Wrapper method to check the parameters of RandomCropDecodeResize and SoftDvppDecodeRandomCropResizeJpeg."""
  177. check_crop_size(size)
  178. if scale is not None:
  179. type_check(scale, (tuple,), "scale")
  180. type_check_list(scale, (float, int), "scale")
  181. if scale[0] > scale[1]:
  182. raise ValueError("scale should be in (min,max) format. Got (max,min).")
  183. check_range(scale, [0, FLOAT_MAX_INTEGER])
  184. check_positive(scale[1], "scale[1]")
  185. if ratio is not None:
  186. type_check(ratio, (tuple,), "ratio")
  187. type_check_list(ratio, (float, int), "ratio")
  188. if ratio[0] > ratio[1]:
  189. raise ValueError("ratio should be in (min,max) format. Got (max,min).")
  190. check_range(ratio, [0, FLOAT_MAX_INTEGER])
  191. check_positive(ratio[0], "ratio[0]")
  192. check_positive(ratio[1], "ratio[1]")
  193. if max_attempts is not None:
  194. check_value(max_attempts, (1, FLOAT_MAX_INTEGER))
  195. def check_random_resize_crop(method):
  196. """A wrapper that wraps a parameter checker around the original function(random resize crop operation)."""
  197. @wraps(method)
  198. def new_method(self, *args, **kwargs):
  199. [size, scale, ratio, interpolation, max_attempts], _ = parse_user_args(method, *args, **kwargs)
  200. if interpolation is not None:
  201. type_check(interpolation, (Inter,), "interpolation")
  202. check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
  203. return method(self, *args, **kwargs)
  204. return new_method
  205. def check_prob(method):
  206. """A wrapper that wraps a parameter checker (to confirm probability) around the original function."""
  207. @wraps(method)
  208. def new_method(self, *args, **kwargs):
  209. [prob], _ = parse_user_args(method, *args, **kwargs)
  210. type_check(prob, (float, int,), "prob")
  211. check_value(prob, [0., 1.], "prob")
  212. return method(self, *args, **kwargs)
  213. return new_method
  214. def check_normalize_c(method):
  215. """A wrapper that wraps a parameter checker around the original function(normalize operation written in C++)."""
  216. @wraps(method)
  217. def new_method(self, *args, **kwargs):
  218. [mean, std], _ = parse_user_args(method, *args, **kwargs)
  219. check_normalize_c_param(mean, std)
  220. return method(self, *args, **kwargs)
  221. return new_method
  222. def check_normalize_py(method):
  223. """A wrapper that wraps a parameter checker around the original function(normalize operation written in Python)."""
  224. @wraps(method)
  225. def new_method(self, *args, **kwargs):
  226. [mean, std], _ = parse_user_args(method, *args, **kwargs)
  227. check_normalize_py_param(mean, std)
  228. return method(self, *args, **kwargs)
  229. return new_method
  230. def check_normalizepad_c(method):
  231. """A wrapper that wraps a parameter checker around the original function(normalizepad operation written in C++)."""
  232. @wraps(method)
  233. def new_method(self, *args, **kwargs):
  234. [mean, std, dtype], _ = parse_user_args(method, *args, **kwargs)
  235. check_normalize_c_param(mean, std)
  236. if not isinstance(dtype, str):
  237. raise TypeError("dtype should be string.")
  238. if dtype not in ["float32", "float16"]:
  239. raise ValueError("dtype only support float32 or float16.")
  240. return method(self, *args, **kwargs)
  241. return new_method
  242. def check_normalizepad_py(method):
  243. """A wrapper that wraps a parameter checker around the original function(normalizepad operation written in Python)."""
  244. @wraps(method)
  245. def new_method(self, *args, **kwargs):
  246. [mean, std, dtype], _ = parse_user_args(method, *args, **kwargs)
  247. check_normalize_py_param(mean, std)
  248. if not isinstance(dtype, str):
  249. raise TypeError("dtype should be string.")
  250. if dtype not in ["float32", "float16"]:
  251. raise ValueError("dtype only support float32 or float16.")
  252. return method(self, *args, **kwargs)
  253. return new_method
  254. def check_random_crop(method):
  255. """Wrapper method to check the parameters of random crop."""
  256. @wraps(method)
  257. def new_method(self, *args, **kwargs):
  258. [size, padding, pad_if_needed, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)
  259. check_crop_size(size)
  260. type_check(pad_if_needed, (bool,), "pad_if_needed")
  261. if padding is not None:
  262. check_padding(padding)
  263. if fill_value is not None:
  264. check_fill_value(fill_value)
  265. if padding_mode is not None:
  266. type_check(padding_mode, (Border,), "padding_mode")
  267. return method(self, *args, **kwargs)
  268. return new_method
  269. def check_random_color_adjust(method):
  270. """Wrapper method to check the parameters of random color adjust."""
  271. @wraps(method)
  272. def new_method(self, *args, **kwargs):
  273. [brightness, contrast, saturation, hue], _ = parse_user_args(method, *args, **kwargs)
  274. check_random_color_adjust_param(brightness, "brightness")
  275. check_random_color_adjust_param(contrast, "contrast")
  276. check_random_color_adjust_param(saturation, "saturation")
  277. check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)
  278. return method(self, *args, **kwargs)
  279. return new_method
  280. def check_random_rotation(method):
  281. """Wrapper method to check the parameters of random rotation."""
  282. @wraps(method)
  283. def new_method(self, *args, **kwargs):
  284. [degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
  285. check_degrees(degrees)
  286. if resample is not None:
  287. type_check(resample, (Inter,), "resample")
  288. if expand is not None:
  289. type_check(expand, (bool,), "expand")
  290. if center is not None:
  291. check_2tuple(center, "center")
  292. if fill_value is not None:
  293. check_fill_value(fill_value)
  294. return method(self, *args, **kwargs)
  295. return new_method
  296. def check_ten_crop(method):
  297. """Wrapper method to check the parameters of crop."""
  298. @wraps(method)
  299. def new_method(self, *args, **kwargs):
  300. [size, use_vertical_flip], _ = parse_user_args(method, *args, **kwargs)
  301. check_crop_size(size)
  302. if use_vertical_flip is not None:
  303. type_check(use_vertical_flip, (bool,), "use_vertical_flip")
  304. return method(self, *args, **kwargs)
  305. return new_method
  306. def check_num_channels(method):
  307. """Wrapper method to check the parameters of number of channels."""
  308. @wraps(method)
  309. def new_method(self, *args, **kwargs):
  310. [num_output_channels], _ = parse_user_args(method, *args, **kwargs)
  311. if num_output_channels is not None:
  312. if num_output_channels not in (1, 3):
  313. raise ValueError("Number of channels of the output grayscale image"
  314. "should be either 1 or 3. Got {0}.".format(num_output_channels))
  315. return method(self, *args, **kwargs)
  316. return new_method
  317. def check_pad(method):
  318. """Wrapper method to check the parameters of random pad."""
  319. @wraps(method)
  320. def new_method(self, *args, **kwargs):
  321. [padding, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)
  322. check_padding(padding)
  323. check_fill_value(fill_value)
  324. type_check(padding_mode, (Border,), "padding_mode")
  325. return method(self, *args, **kwargs)
  326. return new_method
  327. def check_random_perspective(method):
  328. """Wrapper method to check the parameters of random perspective."""
  329. @wraps(method)
  330. def new_method(self, *args, **kwargs):
  331. [distortion_scale, prob, interpolation], _ = parse_user_args(method, *args, **kwargs)
  332. check_value(distortion_scale, [0., 1.], "distortion_scale")
  333. check_value(prob, [0., 1.], "prob")
  334. type_check(interpolation, (Inter,), "interpolation")
  335. return method(self, *args, **kwargs)
  336. return new_method
  337. def check_mix_up(method):
  338. """Wrapper method to check the parameters of mix up."""
  339. @wraps(method)
  340. def new_method(self, *args, **kwargs):
  341. [batch_size, alpha, is_single], _ = parse_user_args(method, *args, **kwargs)
  342. check_value(batch_size, (1, FLOAT_MAX_INTEGER))
  343. check_positive(alpha, "alpha")
  344. type_check(is_single, (bool,), "is_single")
  345. return method(self, *args, **kwargs)
  346. return new_method
  347. def check_random_erasing(method):
  348. """Wrapper method to check the parameters of random erasing."""
  349. @wraps(method)
  350. def new_method(self, *args, **kwargs):
  351. [prob, scale, ratio, value, inplace, max_attempts], _ = parse_user_args(method, *args, **kwargs)
  352. check_value(prob, [0., 1.], "prob")
  353. if scale[0] > scale[1]:
  354. raise ValueError("scale should be in (min,max) format. Got (max,min).")
  355. check_range(scale, [0, FLOAT_MAX_INTEGER])
  356. check_positive(scale[1], "scale[1]")
  357. if ratio[0] > ratio[1]:
  358. raise ValueError("ratio should be in (min,max) format. Got (max,min).")
  359. check_range(ratio, [0, FLOAT_MAX_INTEGER])
  360. check_positive(ratio[0], "ratio[0]")
  361. check_positive(ratio[1], "ratio[1]")
  362. check_erasing_value(value)
  363. type_check(inplace, (bool,), "inplace")
  364. check_value(max_attempts, (1, FLOAT_MAX_INTEGER))
  365. return method(self, *args, **kwargs)
  366. return new_method
  367. def check_cutout(method):
  368. """Wrapper method to check the parameters of cutout operation."""
  369. @wraps(method)
  370. def new_method(self, *args, **kwargs):
  371. [length, num_patches], _ = parse_user_args(method, *args, **kwargs)
  372. type_check(length, (int,), "length")
  373. type_check(num_patches, (int,), "num_patches")
  374. check_value(length, (1, FLOAT_MAX_INTEGER))
  375. check_value(num_patches, (1, FLOAT_MAX_INTEGER))
  376. return method(self, *args, **kwargs)
  377. return new_method
  378. def check_linear_transform(method):
  379. """Wrapper method to check the parameters of linear transform."""
  380. @wraps(method)
  381. def new_method(self, *args, **kwargs):
  382. [transformation_matrix, mean_vector], _ = parse_user_args(method, *args, **kwargs)
  383. type_check(transformation_matrix, (np.ndarray,), "transformation_matrix")
  384. type_check(mean_vector, (np.ndarray,), "mean_vector")
  385. if transformation_matrix.shape[0] != transformation_matrix.shape[1]:
  386. raise ValueError("transformation_matrix should be a square matrix. "
  387. "Got shape {} instead.".format(transformation_matrix.shape))
  388. if mean_vector.shape[0] != transformation_matrix.shape[0]:
  389. raise ValueError("mean_vector length {0} should match either one dimension of the square"
  390. "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
  391. return method(self, *args, **kwargs)
  392. return new_method
  393. def check_random_affine(method):
  394. """Wrapper method to check the parameters of random affine."""
  395. @wraps(method)
  396. def new_method(self, *args, **kwargs):
  397. [degrees, translate, scale, shear, resample, fill_value], _ = parse_user_args(method, *args, **kwargs)
  398. check_degrees(degrees)
  399. if translate is not None:
  400. type_check(translate, (list, tuple), "translate")
  401. type_check_list(translate, (int, float), "translate")
  402. if len(translate) != 2 and len(translate) != 4:
  403. raise TypeError("translate should be a list or tuple of length 2 or 4.")
  404. for i, t in enumerate(translate):
  405. check_value(t, [-1.0, 1.0], "translate at {0}".format(i))
  406. if scale is not None:
  407. type_check(scale, (tuple, list), "scale")
  408. type_check_list(scale, (int, float), "scale")
  409. if len(scale) == 2:
  410. if scale[0] > scale[1]:
  411. raise ValueError("Input scale[1] must be equal to or greater than scale[0].")
  412. check_range(scale, [0, FLOAT_MAX_INTEGER])
  413. check_positive(scale[1], "scale[1]")
  414. else:
  415. raise TypeError("scale should be a list or tuple of length 2.")
  416. if shear is not None:
  417. type_check(shear, (numbers.Number, tuple, list), "shear")
  418. if isinstance(shear, numbers.Number):
  419. check_positive(shear, "shear")
  420. else:
  421. type_check_list(shear, (int, float), "shear")
  422. if len(shear) not in (2, 4):
  423. raise TypeError("shear must be of length 2 or 4.")
  424. if len(shear) == 2 and shear[0] > shear[1]:
  425. raise ValueError("Input shear[1] must be equal to or greater than shear[0]")
  426. if len(shear) == 4 and (shear[0] > shear[1] or shear[2] > shear[3]):
  427. raise ValueError("Input shear[1] must be equal to or greater than shear[0] and "
  428. "shear[3] must be equal to or greater than shear[2].")
  429. type_check(resample, (Inter,), "resample")
  430. if fill_value is not None:
  431. check_fill_value(fill_value)
  432. return method(self, *args, **kwargs)
  433. return new_method
  434. def check_rescale(method):
  435. """Wrapper method to check the parameters of rescale."""
  436. @wraps(method)
  437. def new_method(self, *args, **kwargs):
  438. [rescale, shift], _ = parse_user_args(method, *args, **kwargs)
  439. type_check(rescale, (numbers.Number,), "rescale")
  440. type_check(shift, (numbers.Number,), "shift")
  441. check_float32(rescale)
  442. check_float32(shift)
  443. return method(self, *args, **kwargs)
  444. return new_method
  445. def check_uniform_augment_cpp(method):
  446. """Wrapper method to check the parameters of UniformAugment C++ op."""
  447. @wraps(method)
  448. def new_method(self, *args, **kwargs):
  449. [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
  450. type_check(num_ops, (int,), "num_ops")
  451. check_positive(num_ops, "num_ops")
  452. if num_ops > len(transforms):
  453. raise ValueError("num_ops is greater than transforms list size.")
  454. parsed_transforms = []
  455. for op in transforms:
  456. if op and getattr(op, 'parse', None):
  457. parsed_transforms.append(op.parse())
  458. else:
  459. parsed_transforms.append(op)
  460. type_check(parsed_transforms, (list, tuple,), "transforms")
  461. for index, arg in enumerate(parsed_transforms):
  462. if not isinstance(arg, (TensorOp, TensorOperation)):
  463. raise TypeError("Type of Transforms[{0}] must be c_transform, but got {1}".format(index, type(arg)))
  464. return method(self, *args, **kwargs)
  465. return new_method
  466. def check_bounding_box_augment_cpp(method):
  467. """Wrapper method to check the parameters of BoundingBoxAugment C++ op."""
  468. @wraps(method)
  469. def new_method(self, *args, **kwargs):
  470. [transform, ratio], _ = parse_user_args(method, *args, **kwargs)
  471. type_check(ratio, (float, int), "ratio")
  472. check_value(ratio, [0., 1.], "ratio")
  473. if transform and getattr(transform, 'parse', None):
  474. transform = transform.parse()
  475. type_check(transform, (TensorOp, TensorOperation), "transform")
  476. return method(self, *args, **kwargs)
  477. return new_method
  478. def check_auto_contrast(method):
  479. """Wrapper method to check the parameters of AutoContrast ops (Python and C++)."""
  480. @wraps(method)
  481. def new_method(self, *args, **kwargs):
  482. [cutoff, ignore], _ = parse_user_args(method, *args, **kwargs)
  483. type_check(cutoff, (int, float), "cutoff")
  484. check_value_cutoff(cutoff, [0, 50], "cutoff")
  485. if ignore is not None:
  486. type_check(ignore, (list, tuple, int), "ignore")
  487. if isinstance(ignore, int):
  488. check_value(ignore, [0, 255], "ignore")
  489. if isinstance(ignore, (list, tuple)):
  490. for item in ignore:
  491. type_check(item, (int,), "item")
  492. check_value(item, [0, 255], "ignore")
  493. return method(self, *args, **kwargs)
  494. return new_method
  495. def check_uniform_augment_py(method):
  496. """Wrapper method to check the parameters of Python UniformAugment op."""
  497. @wraps(method)
  498. def new_method(self, *args, **kwargs):
  499. [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
  500. type_check(transforms, (list,), "transforms")
  501. if not transforms:
  502. raise ValueError("transforms list is empty.")
  503. for transform in transforms:
  504. if isinstance(transform, TensorOp):
  505. raise ValueError("transform list only accepts Python operations.")
  506. type_check(num_ops, (int,), "num_ops")
  507. check_positive(num_ops, "num_ops")
  508. if num_ops > len(transforms):
  509. raise ValueError("num_ops cannot be greater than the length of transforms list.")
  510. return method(self, *args, **kwargs)
  511. return new_method
  512. def check_positive_degrees(method):
  513. """A wrapper method to check degrees parameter in RandomSharpness and RandomColor ops (Python and C++)"""
  514. @wraps(method)
  515. def new_method(self, *args, **kwargs):
  516. [degrees], _ = parse_user_args(method, *args, **kwargs)
  517. if degrees is not None:
  518. if not isinstance(degrees, (list, tuple)):
  519. raise TypeError("degrees must be either a tuple or a list.")
  520. type_check_list(degrees, (int, float), "degrees")
  521. if len(degrees) != 2:
  522. raise ValueError("degrees must be a sequence with length 2.")
  523. for degree in degrees:
  524. check_value(degree, (0, FLOAT_MAX_INTEGER))
  525. if degrees[0] > degrees[1]:
  526. raise ValueError("degrees should be in (min,max) format. Got (max,min).")
  527. return method(self, *args, **kwargs)
  528. return new_method
  529. def check_random_select_subpolicy_op(method):
  530. """Wrapper method to check the parameters of RandomSelectSubpolicyOp."""
  531. @wraps(method)
  532. def new_method(self, *args, **kwargs):
  533. [policy], _ = parse_user_args(method, *args, **kwargs)
  534. type_check(policy, (list,), "policy")
  535. if not policy:
  536. raise ValueError("policy can not be empty.")
  537. for sub_ind, sub in enumerate(policy):
  538. type_check(sub, (list,), "policy[{0}]".format([sub_ind]))
  539. if not sub:
  540. raise ValueError("policy[{0}] can not be empty.".format(sub_ind))
  541. for op_ind, tp in enumerate(sub):
  542. check_2tuple(tp, "policy[{0}][{1}]".format(sub_ind, op_ind))
  543. check_c_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind))
  544. check_value(tp[1], (0, 1), "prob of (op, prob) policy[{0}][{1}]".format(sub_ind, op_ind))
  545. return method(self, *args, **kwargs)
  546. return new_method
  547. def check_soft_dvpp_decode_random_crop_resize_jpeg(method):
  548. """Wrapper method to check the parameters of SoftDvppDecodeRandomCropResizeJpeg."""
  549. @wraps(method)
  550. def new_method(self, *args, **kwargs):
  551. [size, scale, ratio, max_attempts], _ = parse_user_args(method, *args, **kwargs)
  552. check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
  553. return method(self, *args, **kwargs)
  554. return new_method
  555. def check_random_solarize(method):
  556. """Wrapper method to check the parameters of RandomSolarizeOp."""
  557. @wraps(method)
  558. def new_method(self, *args, **kwargs):
  559. [threshold], _ = parse_user_args(method, *args, **kwargs)
  560. type_check(threshold, (tuple,), "threshold")
  561. type_check_list(threshold, (int,), "threshold")
  562. if len(threshold) != 2:
  563. raise ValueError("threshold must be a sequence of two numbers.")
  564. for element in threshold:
  565. check_value(element, (0, UINT8_MAX))
  566. if threshold[1] < threshold[0]:
  567. raise ValueError("threshold must be in min max format numbers.")
  568. return method(self, *args, **kwargs)
  569. return new_method