You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

validators.py 39 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Validators for TensorOps.
  16. """
  17. import numbers
  18. from functools import wraps
  19. import numpy as np
  20. from mindspore._c_dataengine import TensorOp, TensorOperation
  21. from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER, \
  22. check_pos_float32, check_float32, check_2tuple, check_range, check_positive, INT32_MAX, INT32_MIN, \
  23. parse_user_args, type_check, type_check_list, check_c_tensor_op, UINT8_MAX, check_value_normalize_std, \
  24. check_value_cutoff, check_value_ratio, check_odd, check_non_negative_float32
  25. from .utils import Inter, Border, ImageBatchFormat, ConvertMode, SliceMode, AutoAugmentPolicy
  26. def check_crop_size(size):
  27. """Wrapper method to check the parameters of crop size."""
  28. type_check(size, (int, list, tuple), "size")
  29. if isinstance(size, int):
  30. check_value(size, (1, FLOAT_MAX_INTEGER))
  31. elif isinstance(size, (tuple, list)) and len(size) == 2:
  32. for index, value in enumerate(size):
  33. type_check(value, (int,), "size[{}]".format(index))
  34. check_value(value, (1, FLOAT_MAX_INTEGER))
  35. else:
  36. raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
  37. def check_crop_coordinates(coordinates):
  38. """Wrapper method to check the parameters of crop size."""
  39. type_check(coordinates, (list, tuple), "coordinates")
  40. if isinstance(coordinates, (tuple, list)) and len(coordinates) == 2:
  41. for index, value in enumerate(coordinates):
  42. type_check(value, (int,), "coordinates[{}]".format(index))
  43. check_value(value, (0, INT32_MAX), "coordinates[{}]".format(index))
  44. else:
  45. raise TypeError("Coordinates should be a list/tuple (y, x) of length 2.")
  46. def check_cut_mix_batch_c(method):
  47. """Wrapper method to check the parameters of CutMixBatch."""
  48. @wraps(method)
  49. def new_method(self, *args, **kwargs):
  50. [image_batch_format, alpha, prob], _ = parse_user_args(method, *args, **kwargs)
  51. type_check(image_batch_format, (ImageBatchFormat,), "image_batch_format")
  52. type_check(alpha, (int, float), "alpha")
  53. type_check(prob, (int, float), "prob")
  54. check_pos_float32(alpha)
  55. check_positive(alpha, "alpha")
  56. check_value(prob, [0, 1], "prob")
  57. return method(self, *args, **kwargs)
  58. return new_method
  59. def check_resize_size(size):
  60. """Wrapper method to check the parameters of resize."""
  61. if isinstance(size, int):
  62. check_value(size, (1, FLOAT_MAX_INTEGER))
  63. elif isinstance(size, (tuple, list)) and len(size) == 2:
  64. for i, value in enumerate(size):
  65. type_check(value, (int,), "size at dim {0}".format(i))
  66. check_value(value, (1, INT32_MAX), "size at dim {0}".format(i))
  67. else:
  68. raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
  69. def check_mix_up_batch_c(method):
  70. """Wrapper method to check the parameters of MixUpBatch."""
  71. @wraps(method)
  72. def new_method(self, *args, **kwargs):
  73. [alpha], _ = parse_user_args(method, *args, **kwargs)
  74. type_check(alpha, (int, float), "alpha")
  75. check_positive(alpha, "alpha")
  76. check_pos_float32(alpha)
  77. return method(self, *args, **kwargs)
  78. return new_method
  79. def check_normalize_c_param(mean, std):
  80. type_check(mean, (list, tuple), "mean")
  81. type_check(std, (list, tuple), "std")
  82. if len(mean) != len(std):
  83. raise ValueError("Length of mean and std must be equal.")
  84. for mean_value in mean:
  85. check_value(mean_value, [0, 255], "mean_value")
  86. for std_value in std:
  87. check_value_normalize_std(std_value, [0, 255], "std_value")
  88. def check_normalize_py_param(mean, std):
  89. type_check(mean, (list, tuple), "mean")
  90. type_check(std, (list, tuple), "std")
  91. if len(mean) != len(std):
  92. raise ValueError("Length of mean and std must be equal.")
  93. for mean_value in mean:
  94. check_value(mean_value, [0., 1.], "mean_value")
  95. for std_value in std:
  96. check_value_normalize_std(std_value, [0., 1.], "std_value")
  97. def check_fill_value(fill_value):
  98. if isinstance(fill_value, int):
  99. check_uint8(fill_value)
  100. elif isinstance(fill_value, tuple) and len(fill_value) == 3:
  101. for value in fill_value:
  102. check_uint8(value)
  103. else:
  104. raise TypeError("fill_value should be a single integer or a 3-tuple.")
  105. def check_padding(padding):
  106. """Parsing the padding arguments and check if it is legal."""
  107. type_check(padding, (tuple, list, numbers.Number), "padding")
  108. if isinstance(padding, numbers.Number):
  109. check_value(padding, (0, INT32_MAX), "padding")
  110. if isinstance(padding, (tuple, list)):
  111. if len(padding) not in (2, 4):
  112. raise ValueError("The size of the padding list or tuple should be 2 or 4.")
  113. for i, pad_value in enumerate(padding):
  114. type_check(pad_value, (int,), "padding[{}]".format(i))
  115. check_value(pad_value, (0, INT32_MAX), "pad_value")
  116. def check_degrees(degrees):
  117. """Check if the `degrees` is legal."""
  118. type_check(degrees, (int, float, list, tuple), "degrees")
  119. if isinstance(degrees, (int, float)):
  120. check_non_negative_float32(degrees, "degrees")
  121. elif isinstance(degrees, (list, tuple)):
  122. if len(degrees) == 2:
  123. type_check_list(degrees, (int, float), "degrees")
  124. for value in degrees:
  125. check_float32(value, "degrees")
  126. if degrees[0] > degrees[1]:
  127. raise ValueError("degrees should be in (min,max) format. Got (max,min).")
  128. else:
  129. raise TypeError("If degrees is a sequence, the length must be 2.")
  130. def check_random_color_adjust_param(value, input_name, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True):
  131. """Check the parameters in random color adjust operation."""
  132. type_check(value, (numbers.Number, list, tuple), input_name)
  133. if isinstance(value, numbers.Number):
  134. if value < 0:
  135. raise ValueError("The input value of {} cannot be negative.".format(input_name))
  136. elif isinstance(value, (list, tuple)):
  137. if len(value) != 2:
  138. raise TypeError("If {0} is a sequence, the length must be 2.".format(input_name))
  139. if value[0] > value[1]:
  140. raise ValueError("{0} value should be in (min,max) format. Got ({1}, {2}).".format(input_name,
  141. value[0], value[1]))
  142. check_range(value, bound)
  143. def check_erasing_value(value):
  144. if not (isinstance(value, (numbers.Number,)) or
  145. (isinstance(value, (str,)) and value == 'random') or
  146. (isinstance(value, (tuple, list)) and len(value) == 3)):
  147. raise ValueError("The value for erasing should be either a single value, "
  148. "or a string 'random', or a sequence of 3 elements for RGB respectively.")
  149. def check_crop(method):
  150. """A wrapper that wraps a parameter checker around the original function(crop operation)."""
  151. @wraps(method)
  152. def new_method(self, *args, **kwargs):
  153. [coordinates, size], _ = parse_user_args(method, *args, **kwargs)
  154. check_crop_coordinates(coordinates)
  155. check_crop_size(size)
  156. return method(self, *args, **kwargs)
  157. return new_method
  158. def check_center_crop(method):
  159. """A wrapper that wraps a parameter checker around the original function(center crop operation)."""
  160. @wraps(method)
  161. def new_method(self, *args, **kwargs):
  162. [size], _ = parse_user_args(method, *args, **kwargs)
  163. check_crop_size(size)
  164. return method(self, *args, **kwargs)
  165. return new_method
  166. def check_five_crop(method):
  167. """A wrapper that wraps a parameter checker around the original function(five crop operation)."""
  168. @wraps(method)
  169. def new_method(self, *args, **kwargs):
  170. [size], _ = parse_user_args(method, *args, **kwargs)
  171. check_crop_size(size)
  172. return method(self, *args, **kwargs)
  173. return new_method
  174. def check_posterize(method):
  175. """A wrapper that wraps a parameter checker around the original function(posterize operation)."""
  176. @wraps(method)
  177. def new_method(self, *args, **kwargs):
  178. [bits], _ = parse_user_args(method, *args, **kwargs)
  179. if bits is not None:
  180. type_check(bits, (list, tuple, int), "bits")
  181. if isinstance(bits, int):
  182. check_value(bits, [1, 8])
  183. if isinstance(bits, (list, tuple)):
  184. if len(bits) != 2:
  185. raise TypeError("Size of bits should be a single integer or a list/tuple (min, max) of length 2.")
  186. for item in bits:
  187. check_uint8(item, "bits")
  188. # also checks if min <= max
  189. check_range(bits, [1, 8])
  190. return method(self, *args, **kwargs)
  191. return new_method
  192. def check_resize_interpolation(method):
  193. """A wrapper that wraps a parameter checker around the original function(resize interpolation operation)."""
  194. @wraps(method)
  195. def new_method(self, *args, **kwargs):
  196. [size, interpolation], _ = parse_user_args(method, *args, **kwargs)
  197. if interpolation is None:
  198. raise KeyError("Interpolation should not be None")
  199. check_resize_size(size)
  200. type_check(interpolation, (Inter,), "interpolation")
  201. return method(self, *args, **kwargs)
  202. return new_method
  203. def check_resize(method):
  204. """A wrapper that wraps a parameter checker around the original function(resize operation)."""
  205. @wraps(method)
  206. def new_method(self, *args, **kwargs):
  207. [size], _ = parse_user_args(method, *args, **kwargs)
  208. check_resize_size(size)
  209. return method(self, *args, **kwargs)
  210. return new_method
  211. def check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts):
  212. """Wrapper method to check the parameters of RandomCropDecodeResize and SoftDvppDecodeRandomCropResizeJpeg."""
  213. check_crop_size(size)
  214. if scale is not None:
  215. type_check(scale, (tuple, list), "scale")
  216. if len(scale) != 2:
  217. raise TypeError("scale should be a list/tuple of length 2.")
  218. type_check_list(scale, (float, int), "scale")
  219. if scale[0] > scale[1]:
  220. raise ValueError("scale should be in (min,max) format. Got (max,min).")
  221. check_range(scale, [0, FLOAT_MAX_INTEGER])
  222. check_positive(scale[1], "scale[1]")
  223. if ratio is not None:
  224. type_check(ratio, (tuple, list), "ratio")
  225. if len(ratio) != 2:
  226. raise TypeError("ratio should be a list/tuple of length 2.")
  227. type_check_list(ratio, (float, int), "ratio")
  228. if ratio[0] > ratio[1]:
  229. raise ValueError("ratio should be in (min,max) format. Got (max,min).")
  230. check_range(ratio, [0, FLOAT_MAX_INTEGER])
  231. check_positive(ratio[0], "ratio[0]")
  232. check_positive(ratio[1], "ratio[1]")
  233. if max_attempts is not None:
  234. check_value(max_attempts, (1, FLOAT_MAX_INTEGER))
  235. def check_random_adjust_sharpness(method):
  236. """Wrapper method to check the parameters of RandomAdjustSharpness."""
  237. @wraps(method)
  238. def new_method(self, *args, **kwargs):
  239. [degree, prob], _ = parse_user_args(method, *args, **kwargs)
  240. type_check(degree, (float, int), "degree")
  241. check_non_negative_float32(degree, "degree")
  242. type_check(prob, (float, int), "prob")
  243. check_value(prob, [0., 1.], "prob")
  244. return method(self, *args, **kwargs)
  245. return new_method
  246. def check_random_resize_crop(method):
  247. """A wrapper that wraps a parameter checker around the original function(random resize crop operation)."""
  248. @wraps(method)
  249. def new_method(self, *args, **kwargs):
  250. [size, scale, ratio, interpolation, max_attempts], _ = parse_user_args(method, *args, **kwargs)
  251. if interpolation is not None:
  252. type_check(interpolation, (Inter,), "interpolation")
  253. check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
  254. return method(self, *args, **kwargs)
  255. return new_method
  256. def check_random_auto_contrast(method):
  257. """Wrapper method to check the parameters of Python RandomAutoContrast op."""
  258. @wraps(method)
  259. def new_method(self, *args, **kwargs):
  260. [cutoff, ignore, prob], _ = parse_user_args(method, *args, **kwargs)
  261. type_check(cutoff, (int, float), "cutoff")
  262. check_value_cutoff(cutoff, [0, 50], "cutoff")
  263. if ignore is not None:
  264. type_check(ignore, (list, tuple, int), "ignore")
  265. if isinstance(ignore, int):
  266. check_value(ignore, [0, 255], "ignore")
  267. if isinstance(ignore, (list, tuple)):
  268. for item in ignore:
  269. type_check(item, (int,), "item")
  270. check_value(item, [0, 255], "ignore")
  271. type_check(prob, (float, int,), "prob")
  272. check_value(prob, [0., 1.], "prob")
  273. return method(self, *args, **kwargs)
  274. return new_method
  275. def check_prob(method):
  276. """A wrapper that wraps a parameter checker (to confirm probability) around the original function."""
  277. @wraps(method)
  278. def new_method(self, *args, **kwargs):
  279. [prob], _ = parse_user_args(method, *args, **kwargs)
  280. type_check(prob, (float, int,), "prob")
  281. check_value(prob, [0., 1.], "prob")
  282. return method(self, *args, **kwargs)
  283. return new_method
  284. def check_alpha(method):
  285. """A wrapper method to check alpha parameter in RandomLighting."""
  286. @wraps(method)
  287. def new_method(self, *args, **kwargs):
  288. [alpha], _ = parse_user_args(method, *args, **kwargs)
  289. type_check(alpha, (float, int,), "alpha")
  290. check_non_negative_float32(alpha, "alpha")
  291. return method(self, *args, **kwargs)
  292. return new_method
  293. def check_normalize_c(method):
  294. """A wrapper that wraps a parameter checker around the original function(normalize operation written in C++)."""
  295. @wraps(method)
  296. def new_method(self, *args, **kwargs):
  297. [mean, std], _ = parse_user_args(method, *args, **kwargs)
  298. check_normalize_c_param(mean, std)
  299. return method(self, *args, **kwargs)
  300. return new_method
  301. def check_normalize_py(method):
  302. """A wrapper that wraps a parameter checker around the original function(normalize operation written in Python)."""
  303. @wraps(method)
  304. def new_method(self, *args, **kwargs):
  305. [mean, std], _ = parse_user_args(method, *args, **kwargs)
  306. check_normalize_py_param(mean, std)
  307. return method(self, *args, **kwargs)
  308. return new_method
  309. def check_normalizepad_c(method):
  310. """A wrapper that wraps a parameter checker around the original function(normalizepad written in C++)."""
  311. @wraps(method)
  312. def new_method(self, *args, **kwargs):
  313. [mean, std, dtype], _ = parse_user_args(method, *args, **kwargs)
  314. check_normalize_c_param(mean, std)
  315. if not isinstance(dtype, str):
  316. raise TypeError("dtype should be string.")
  317. if dtype not in ["float32", "float16"]:
  318. raise ValueError("dtype only support float32 or float16.")
  319. return method(self, *args, **kwargs)
  320. return new_method
  321. def check_normalizepad_py(method):
  322. """A wrapper that wraps a parameter checker around the original function(normalizepad written in Python)."""
  323. @wraps(method)
  324. def new_method(self, *args, **kwargs):
  325. [mean, std, dtype], _ = parse_user_args(method, *args, **kwargs)
  326. check_normalize_py_param(mean, std)
  327. if not isinstance(dtype, str):
  328. raise TypeError("dtype should be string.")
  329. if dtype not in ["float32", "float16"]:
  330. raise ValueError("dtype only support float32 or float16.")
  331. return method(self, *args, **kwargs)
  332. return new_method
  333. def check_random_crop(method):
  334. """Wrapper method to check the parameters of random crop."""
  335. @wraps(method)
  336. def new_method(self, *args, **kwargs):
  337. [size, padding, pad_if_needed, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)
  338. check_crop_size(size)
  339. type_check(pad_if_needed, (bool,), "pad_if_needed")
  340. if padding is not None:
  341. check_padding(padding)
  342. if fill_value is not None:
  343. check_fill_value(fill_value)
  344. if padding_mode is not None:
  345. type_check(padding_mode, (Border,), "padding_mode")
  346. return method(self, *args, **kwargs)
  347. return new_method
  348. def check_random_color_adjust(method):
  349. """Wrapper method to check the parameters of random color adjust."""
  350. @wraps(method)
  351. def new_method(self, *args, **kwargs):
  352. [brightness, contrast, saturation, hue], _ = parse_user_args(method, *args, **kwargs)
  353. check_random_color_adjust_param(brightness, "brightness")
  354. check_random_color_adjust_param(contrast, "contrast")
  355. check_random_color_adjust_param(saturation, "saturation")
  356. check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)
  357. return method(self, *args, **kwargs)
  358. return new_method
  359. def check_resample_expand_center_fill_value_params(resample, expand, center, fill_value):
  360. type_check(resample, (Inter,), "resample")
  361. type_check(expand, (bool,), "expand")
  362. if center is not None:
  363. check_2tuple(center, "center")
  364. for value in center:
  365. type_check(value, (int, float), "center")
  366. check_value(value, [INT32_MIN, INT32_MAX], "center")
  367. check_fill_value(fill_value)
  368. def check_random_rotation(method):
  369. """Wrapper method to check the parameters of random rotation."""
  370. @wraps(method)
  371. def new_method(self, *args, **kwargs):
  372. [degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
  373. check_degrees(degrees)
  374. check_resample_expand_center_fill_value_params(resample, expand, center, fill_value)
  375. return method(self, *args, **kwargs)
  376. return new_method
  377. def check_rotate(method):
  378. """Wrapper method to check the parameters of rotate."""
  379. @wraps(method)
  380. def new_method(self, *args, **kwargs):
  381. [degrees, resample, expand, center, fill_value], _ = parse_user_args(method, *args, **kwargs)
  382. type_check(degrees, (float, int), "degrees")
  383. check_float32(degrees, "degrees")
  384. check_resample_expand_center_fill_value_params(resample, expand, center, fill_value)
  385. return method(self, *args, **kwargs)
  386. return new_method
  387. def check_ten_crop(method):
  388. """Wrapper method to check the parameters of crop."""
  389. @wraps(method)
  390. def new_method(self, *args, **kwargs):
  391. [size, use_vertical_flip], _ = parse_user_args(method, *args, **kwargs)
  392. check_crop_size(size)
  393. if use_vertical_flip is not None:
  394. type_check(use_vertical_flip, (bool,), "use_vertical_flip")
  395. return method(self, *args, **kwargs)
  396. return new_method
  397. def check_num_channels(method):
  398. """Wrapper method to check the parameters of number of channels."""
  399. @wraps(method)
  400. def new_method(self, *args, **kwargs):
  401. [num_output_channels], _ = parse_user_args(method, *args, **kwargs)
  402. if num_output_channels is not None:
  403. if num_output_channels not in (1, 3):
  404. raise ValueError("Number of channels of the output grayscale image"
  405. "should be either 1 or 3. Got {0}.".format(num_output_channels))
  406. return method(self, *args, **kwargs)
  407. return new_method
  408. def check_pad(method):
  409. """Wrapper method to check the parameters of random pad."""
  410. @wraps(method)
  411. def new_method(self, *args, **kwargs):
  412. [padding, fill_value, padding_mode], _ = parse_user_args(method, *args, **kwargs)
  413. check_padding(padding)
  414. check_fill_value(fill_value)
  415. type_check(padding_mode, (Border,), "padding_mode")
  416. return method(self, *args, **kwargs)
  417. return new_method
  418. def check_slice_patches(method):
  419. """Wrapper method to check the parameters of slice patches."""
  420. @wraps(method)
  421. def new_method(self, *args, **kwargs):
  422. [num_height, num_width, slice_mode, fill_value], _ = parse_user_args(method, *args, **kwargs)
  423. if num_height is not None:
  424. type_check(num_height, (int,), "num_height")
  425. check_value(num_height, (1, INT32_MAX), "num_height")
  426. if num_width is not None:
  427. type_check(num_width, (int,), "num_width")
  428. check_value(num_width, (1, INT32_MAX), "num_width")
  429. if slice_mode is not None:
  430. type_check(slice_mode, (SliceMode,), "slice_mode")
  431. if fill_value is not None:
  432. type_check(fill_value, (int,), "fill_value")
  433. check_value(fill_value, [0, 255], "fill_value")
  434. return method(self, *args, **kwargs)
  435. return new_method
  436. def check_random_perspective(method):
  437. """Wrapper method to check the parameters of random perspective."""
  438. @wraps(method)
  439. def new_method(self, *args, **kwargs):
  440. [distortion_scale, prob, interpolation], _ = parse_user_args(method, *args, **kwargs)
  441. type_check(distortion_scale, (float,), "distortion_scale")
  442. type_check(prob, (float,), "prob")
  443. check_value(distortion_scale, [0., 1.], "distortion_scale")
  444. check_value(prob, [0., 1.], "prob")
  445. type_check(interpolation, (Inter,), "interpolation")
  446. return method(self, *args, **kwargs)
  447. return new_method
  448. def check_mix_up(method):
  449. """Wrapper method to check the parameters of mix up."""
  450. @wraps(method)
  451. def new_method(self, *args, **kwargs):
  452. [batch_size, alpha, is_single], _ = parse_user_args(method, *args, **kwargs)
  453. type_check(is_single, (bool,), "is_single")
  454. type_check(batch_size, (int,), "batch_size")
  455. type_check(alpha, (int, float), "alpha")
  456. check_value(batch_size, (1, FLOAT_MAX_INTEGER))
  457. check_positive(alpha, "alpha")
  458. return method(self, *args, **kwargs)
  459. return new_method
  460. def check_rgb_to_bgr(method):
  461. """Wrapper method to check the parameters of rgb_to_bgr."""
  462. @wraps(method)
  463. def new_method(self, *args, **kwargs):
  464. [is_hwc], _ = parse_user_args(method, *args, **kwargs)
  465. type_check(is_hwc, (bool,), "is_hwc")
  466. return method(self, *args, **kwargs)
  467. return new_method
  468. def check_rgb_to_hsv(method):
  469. """Wrapper method to check the parameters of rgb_to_hsv."""
  470. @wraps(method)
  471. def new_method(self, *args, **kwargs):
  472. [is_hwc], _ = parse_user_args(method, *args, **kwargs)
  473. type_check(is_hwc, (bool,), "is_hwc")
  474. return method(self, *args, **kwargs)
  475. return new_method
  476. def check_hsv_to_rgb(method):
  477. """Wrapper method to check the parameters of hsv_to_rgb."""
  478. @wraps(method)
  479. def new_method(self, *args, **kwargs):
  480. [is_hwc], _ = parse_user_args(method, *args, **kwargs)
  481. type_check(is_hwc, (bool,), "is_hwc")
  482. return method(self, *args, **kwargs)
  483. return new_method
  484. def check_random_erasing(method):
  485. """Wrapper method to check the parameters of random erasing."""
  486. @wraps(method)
  487. def new_method(self, *args, **kwargs):
  488. [prob, scale, ratio, value, inplace, max_attempts], _ = parse_user_args(method, *args, **kwargs)
  489. type_check(prob, (float, int,), "prob")
  490. type_check_list(scale, (float, int,), "scale")
  491. if len(scale) != 2:
  492. raise TypeError("scale should be a list or tuple of length 2.")
  493. type_check_list(ratio, (float, int,), "ratio")
  494. if len(ratio) != 2:
  495. raise TypeError("ratio should be a list or tuple of length 2.")
  496. type_check(value, (int, list, tuple, str), "value")
  497. type_check(inplace, (bool,), "inplace")
  498. type_check(max_attempts, (int,), "max_attempts")
  499. check_erasing_value(value)
  500. check_value(prob, [0., 1.], "prob")
  501. if scale[0] > scale[1]:
  502. raise ValueError("scale should be in (min,max) format. Got (max,min).")
  503. check_range(scale, [0, FLOAT_MAX_INTEGER])
  504. check_positive(scale[1], "scale[1]")
  505. if ratio[0] > ratio[1]:
  506. raise ValueError("ratio should be in (min,max) format. Got (max,min).")
  507. check_value_ratio(ratio[0], [0, FLOAT_MAX_INTEGER])
  508. check_value_ratio(ratio[1], [0, FLOAT_MAX_INTEGER])
  509. if isinstance(value, int):
  510. check_value(value, (0, 255))
  511. if isinstance(value, (list, tuple)):
  512. for item in value:
  513. type_check(item, (int,), "value")
  514. check_value(item, [0, 255], "value")
  515. check_value(max_attempts, (1, FLOAT_MAX_INTEGER))
  516. return method(self, *args, **kwargs)
  517. return new_method
  518. def check_cutout(method):
  519. """Wrapper method to check the parameters of cutout operation."""
  520. @wraps(method)
  521. def new_method(self, *args, **kwargs):
  522. [length, num_patches], _ = parse_user_args(method, *args, **kwargs)
  523. type_check(length, (int,), "length")
  524. type_check(num_patches, (int,), "num_patches")
  525. check_value(length, (1, FLOAT_MAX_INTEGER))
  526. check_value(num_patches, (1, FLOAT_MAX_INTEGER))
  527. return method(self, *args, **kwargs)
  528. return new_method
  529. def check_linear_transform(method):
  530. """Wrapper method to check the parameters of linear transform."""
  531. @wraps(method)
  532. def new_method(self, *args, **kwargs):
  533. [transformation_matrix, mean_vector], _ = parse_user_args(method, *args, **kwargs)
  534. type_check(transformation_matrix, (np.ndarray,), "transformation_matrix")
  535. type_check(mean_vector, (np.ndarray,), "mean_vector")
  536. if transformation_matrix.shape[0] != transformation_matrix.shape[1]:
  537. raise ValueError("transformation_matrix should be a square matrix. "
  538. "Got shape {} instead.".format(transformation_matrix.shape))
  539. if mean_vector.shape[0] != transformation_matrix.shape[0]:
  540. raise ValueError("mean_vector length {0} should match either one dimension of the square"
  541. "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
  542. return method(self, *args, **kwargs)
  543. return new_method
  544. def check_random_affine(method):
  545. """Wrapper method to check the parameters of random affine."""
  546. @wraps(method)
  547. def new_method(self, *args, **kwargs):
  548. [degrees, translate, scale, shear, resample, fill_value], _ = parse_user_args(method, *args, **kwargs)
  549. check_degrees(degrees)
  550. if translate is not None:
  551. type_check(translate, (list, tuple), "translate")
  552. type_check_list(translate, (int, float), "translate")
  553. if len(translate) != 2 and len(translate) != 4:
  554. raise TypeError("translate should be a list or tuple of length 2 or 4.")
  555. for i, t in enumerate(translate):
  556. check_value(t, [-1.0, 1.0], "translate at {0}".format(i))
  557. if scale is not None:
  558. type_check(scale, (tuple, list), "scale")
  559. type_check_list(scale, (int, float), "scale")
  560. if len(scale) == 2:
  561. if scale[0] > scale[1]:
  562. raise ValueError("Input scale[1] must be equal to or greater than scale[0].")
  563. check_range(scale, [0, FLOAT_MAX_INTEGER])
  564. check_positive(scale[1], "scale[1]")
  565. else:
  566. raise TypeError("scale should be a list or tuple of length 2.")
  567. if shear is not None:
  568. type_check(shear, (numbers.Number, tuple, list), "shear")
  569. if isinstance(shear, numbers.Number):
  570. check_positive(shear, "shear")
  571. else:
  572. type_check_list(shear, (int, float), "shear")
  573. if len(shear) not in (2, 4):
  574. raise TypeError("shear must be of length 2 or 4.")
  575. if len(shear) == 2 and shear[0] > shear[1]:
  576. raise ValueError("Input shear[1] must be equal to or greater than shear[0]")
  577. if len(shear) == 4 and (shear[0] > shear[1] or shear[2] > shear[3]):
  578. raise ValueError("Input shear[1] must be equal to or greater than shear[0] and "
  579. "shear[3] must be equal to or greater than shear[2].")
  580. type_check(resample, (Inter,), "resample")
  581. if fill_value is not None:
  582. check_fill_value(fill_value)
  583. return method(self, *args, **kwargs)
  584. return new_method
  585. def check_rescale(method):
  586. """Wrapper method to check the parameters of rescale."""
  587. @wraps(method)
  588. def new_method(self, *args, **kwargs):
  589. [rescale, shift], _ = parse_user_args(method, *args, **kwargs)
  590. type_check(rescale, (numbers.Number,), "rescale")
  591. type_check(shift, (numbers.Number,), "shift")
  592. check_float32(rescale)
  593. check_float32(shift)
  594. return method(self, *args, **kwargs)
  595. return new_method
  596. def check_uniform_augment_cpp(method):
  597. """Wrapper method to check the parameters of UniformAugment C++ op."""
  598. @wraps(method)
  599. def new_method(self, *args, **kwargs):
  600. [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
  601. type_check(num_ops, (int,), "num_ops")
  602. check_positive(num_ops, "num_ops")
  603. if num_ops > len(transforms):
  604. raise ValueError("num_ops is greater than transforms list size.")
  605. parsed_transforms = []
  606. for op in transforms:
  607. if op and getattr(op, 'parse', None):
  608. parsed_transforms.append(op.parse())
  609. else:
  610. parsed_transforms.append(op)
  611. type_check(parsed_transforms, (list, tuple,), "transforms")
  612. for index, arg in enumerate(parsed_transforms):
  613. if not isinstance(arg, (TensorOp, TensorOperation)):
  614. raise TypeError("Type of Transforms[{0}] must be c_transform, but got {1}".format(index, type(arg)))
  615. return method(self, *args, **kwargs)
  616. return new_method
  617. def check_bounding_box_augment_cpp(method):
  618. """Wrapper method to check the parameters of BoundingBoxAugment C++ op."""
  619. @wraps(method)
  620. def new_method(self, *args, **kwargs):
  621. [transform, ratio], _ = parse_user_args(method, *args, **kwargs)
  622. type_check(ratio, (float, int), "ratio")
  623. check_value(ratio, [0., 1.], "ratio")
  624. if transform and getattr(transform, 'parse', None):
  625. transform = transform.parse()
  626. type_check(transform, (TensorOp, TensorOperation), "transform")
  627. return method(self, *args, **kwargs)
  628. return new_method
  629. def check_adjust_gamma(method):
  630. """Wrapper method to check the parameters of AdjustGamma ops (Python and C++)."""
  631. @wraps(method)
  632. def new_method(self, *args, **kwargs):
  633. [gamma, gain], _ = parse_user_args(method, *args, **kwargs)
  634. type_check(gamma, (float, int), "gamma")
  635. check_value(gamma, (0, FLOAT_MAX_INTEGER))
  636. if gain is not None:
  637. type_check(gain, (float, int), "gain")
  638. check_value(gain, (FLOAT_MIN_INTEGER, FLOAT_MAX_INTEGER))
  639. return method(self, *args, **kwargs)
  640. return new_method
  641. def check_auto_contrast(method):
  642. """Wrapper method to check the parameters of AutoContrast ops (Python and C++)."""
  643. @wraps(method)
  644. def new_method(self, *args, **kwargs):
  645. [cutoff, ignore], _ = parse_user_args(method, *args, **kwargs)
  646. type_check(cutoff, (int, float), "cutoff")
  647. check_value_cutoff(cutoff, [0, 50], "cutoff")
  648. if ignore is not None:
  649. type_check(ignore, (list, tuple, int), "ignore")
  650. if isinstance(ignore, int):
  651. check_value(ignore, [0, 255], "ignore")
  652. if isinstance(ignore, (list, tuple)):
  653. for item in ignore:
  654. type_check(item, (int,), "item")
  655. check_value(item, [0, 255], "ignore")
  656. return method(self, *args, **kwargs)
  657. return new_method
  658. def check_uniform_augment_py(method):
  659. """Wrapper method to check the parameters of Python UniformAugment op."""
  660. @wraps(method)
  661. def new_method(self, *args, **kwargs):
  662. [transforms, num_ops], _ = parse_user_args(method, *args, **kwargs)
  663. type_check(transforms, (list,), "transforms")
  664. if not transforms:
  665. raise ValueError("transforms list is empty.")
  666. for transform in transforms:
  667. if isinstance(transform, TensorOp):
  668. raise ValueError("transform list only accepts Python operations.")
  669. type_check(num_ops, (int,), "num_ops")
  670. check_positive(num_ops, "num_ops")
  671. if num_ops > len(transforms):
  672. raise ValueError("num_ops cannot be greater than the length of transforms list.")
  673. return method(self, *args, **kwargs)
  674. return new_method
  675. def check_positive_degrees(method):
  676. """A wrapper method to check degrees parameter in RandomSharpness and RandomColor ops (Python and C++)"""
  677. @wraps(method)
  678. def new_method(self, *args, **kwargs):
  679. [degrees], _ = parse_user_args(method, *args, **kwargs)
  680. if degrees is not None:
  681. if not isinstance(degrees, (list, tuple)):
  682. raise TypeError("degrees must be either a tuple or a list.")
  683. type_check_list(degrees, (int, float), "degrees")
  684. if len(degrees) != 2:
  685. raise ValueError("degrees must be a sequence with length 2.")
  686. for degree in degrees:
  687. check_value(degree, (0, FLOAT_MAX_INTEGER))
  688. if degrees[0] > degrees[1]:
  689. raise ValueError("degrees should be in (min,max) format. Got (max,min).")
  690. return method(self, *args, **kwargs)
  691. return new_method
  692. def check_random_select_subpolicy_op(method):
  693. """Wrapper method to check the parameters of RandomSelectSubpolicyOp."""
  694. @wraps(method)
  695. def new_method(self, *args, **kwargs):
  696. [policy], _ = parse_user_args(method, *args, **kwargs)
  697. type_check(policy, (list,), "policy")
  698. if not policy:
  699. raise ValueError("policy can not be empty.")
  700. for sub_ind, sub in enumerate(policy):
  701. type_check(sub, (list,), "policy[{0}]".format([sub_ind]))
  702. if not sub:
  703. raise ValueError("policy[{0}] can not be empty.".format(sub_ind))
  704. for op_ind, tp in enumerate(sub):
  705. check_2tuple(tp, "policy[{0}][{1}]".format(sub_ind, op_ind))
  706. check_c_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind))
  707. check_value(tp[1], (0, 1), "prob of (op, prob) policy[{0}][{1}]".format(sub_ind, op_ind))
  708. return method(self, *args, **kwargs)
  709. return new_method
  710. def check_soft_dvpp_decode_random_crop_resize_jpeg(method):
  711. """Wrapper method to check the parameters of SoftDvppDecodeRandomCropResizeJpeg."""
  712. @wraps(method)
  713. def new_method(self, *args, **kwargs):
  714. [size, scale, ratio, max_attempts], _ = parse_user_args(method, *args, **kwargs)
  715. check_size_scale_ration_max_attempts_paras(size, scale, ratio, max_attempts)
  716. return method(self, *args, **kwargs)
  717. return new_method
  718. def check_random_solarize(method):
  719. """Wrapper method to check the parameters of RandomSolarizeOp."""
  720. @wraps(method)
  721. def new_method(self, *args, **kwargs):
  722. [threshold], _ = parse_user_args(method, *args, **kwargs)
  723. type_check(threshold, (tuple,), "threshold")
  724. type_check_list(threshold, (int,), "threshold")
  725. if len(threshold) != 2:
  726. raise ValueError("threshold must be a sequence of two numbers.")
  727. for element in threshold:
  728. check_value(element, (0, UINT8_MAX))
  729. if threshold[1] < threshold[0]:
  730. raise ValueError("threshold must be in min max format numbers.")
  731. return method(self, *args, **kwargs)
  732. return new_method
  733. def check_gaussian_blur(method):
  734. """Wrapper method to check the parameters of GaussianBlur."""
  735. @wraps(method)
  736. def new_method(self, *args, **kwargs):
  737. [kernel_size, sigma], _ = parse_user_args(method, *args, **kwargs)
  738. type_check(kernel_size, (int, list, tuple), "kernel_size")
  739. if isinstance(kernel_size, int):
  740. check_value(kernel_size, (1, FLOAT_MAX_INTEGER), "kernel_size")
  741. check_odd(kernel_size, "kernel_size")
  742. elif isinstance(kernel_size, (list, tuple)) and len(kernel_size) == 2:
  743. for index, value in enumerate(kernel_size):
  744. type_check(value, (int,), "kernel_size[{}]".format(index))
  745. check_value(value, (1, FLOAT_MAX_INTEGER), "kernel_size")
  746. check_odd(value, "kernel_size[{}]".format(index))
  747. else:
  748. raise TypeError(
  749. "Kernel size should be a single integer or a list/tuple (kernel_width, kernel_height) of length 2.")
  750. if sigma is not None:
  751. type_check(sigma, (numbers.Number, list, tuple), "sigma")
  752. if isinstance(sigma, numbers.Number):
  753. check_value(sigma, (0, FLOAT_MAX_INTEGER), "sigma")
  754. elif isinstance(sigma, (list, tuple)) and len(sigma) == 2:
  755. for index, value in enumerate(sigma):
  756. type_check(value, (numbers.Number,), "size[{}]".format(index))
  757. check_value(value, (0, FLOAT_MAX_INTEGER), "sigma")
  758. else:
  759. raise TypeError("Sigma should be a single number or a list/tuple of length 2 for width and height.")
  760. return method(self, *args, **kwargs)
  761. return new_method
  762. def check_convert_color(method):
  763. """Wrapper method to check the parameters of convertcolor."""
  764. @wraps(method)
  765. def new_method(self, *args, **kwargs):
  766. [convert_mode], _ = parse_user_args(method, *args, **kwargs)
  767. if convert_mode is not None:
  768. type_check(convert_mode, (ConvertMode,), "convert_mode")
  769. return method(self, *args, **kwargs)
  770. return new_method
  771. def check_auto_augment(method):
  772. """Wrapper method to check the parameters of AutoAugment."""
  773. @wraps(method)
  774. def new_method(self, *args, **kwargs):
  775. [policy, interpolation, fill_value], _ = parse_user_args(method, *args, **kwargs)
  776. type_check(policy, (AutoAugmentPolicy,), "policy")
  777. type_check(interpolation, (Inter,), "interpolation")
  778. check_fill_value(fill_value)
  779. return method(self, *args, **kwargs)
  780. return new_method