You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

validators.py 33 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Validators for TensorOps.
  16. """
  17. import numbers
  18. from functools import wraps
  19. from mindspore._c_dataengine import TensorOp
  20. from .utils import Inter, Border
  21. from ...transforms.validators import check_pos_int32, check_pos_float32, check_value, check_uint8, FLOAT_MAX_INTEGER, \
  22. check_bool, check_2tuple, check_range, check_list, check_type, check_positive, INT32_MAX
  23. def check_inter_mode(mode):
  24. if not isinstance(mode, Inter):
  25. raise ValueError("Invalid interpolation mode.")
  26. def check_border_type(mode):
  27. if not isinstance(mode, Border):
  28. raise ValueError("Invalid padding mode.")
  29. def check_crop_size(size):
  30. """Wrapper method to check the parameters of crop size."""
  31. if isinstance(size, int):
  32. size = (size, size)
  33. elif isinstance(size, (tuple, list)) and len(size) == 2:
  34. size = size
  35. else:
  36. raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
  37. for value in size:
  38. check_pos_int32(value)
  39. return size
  40. def check_resize_size(size):
  41. """Wrapper method to check the parameters of resize."""
  42. if isinstance(size, int):
  43. check_pos_int32(size)
  44. elif isinstance(size, (tuple, list)) and len(size) == 2:
  45. for value in size:
  46. check_value(value, (1, INT32_MAX))
  47. else:
  48. raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
  49. return size
  50. def check_normalize_c_param(mean, std):
  51. if len(mean) != len(std):
  52. raise ValueError("Length of mean and std must be equal")
  53. for mean_value in mean:
  54. check_pos_float32(mean_value)
  55. for std_value in std:
  56. check_pos_float32(std_value)
  57. def check_normalize_py_param(mean, std):
  58. if len(mean) != len(std):
  59. raise ValueError("Length of mean and std must be equal")
  60. for mean_value in mean:
  61. check_value(mean_value, [0., 1.])
  62. for std_value in std:
  63. check_value(std_value, [0., 1.])
  64. def check_fill_value(fill_value):
  65. if isinstance(fill_value, int):
  66. check_uint8(fill_value)
  67. elif isinstance(fill_value, tuple) and len(fill_value) == 3:
  68. for value in fill_value:
  69. check_uint8(value)
  70. else:
  71. raise TypeError("fill_value should be a single integer or a 3-tuple.")
  72. return fill_value
  73. def check_padding(padding):
  74. """Parsing the padding arguments and check if it is legal."""
  75. if isinstance(padding, numbers.Number):
  76. top = bottom = left = right = padding
  77. elif isinstance(padding, (tuple, list)):
  78. if len(padding) == 2:
  79. left = right = padding[0]
  80. top = bottom = padding[1]
  81. elif len(padding) == 4:
  82. left = padding[0]
  83. top = padding[1]
  84. right = padding[2]
  85. bottom = padding[3]
  86. else:
  87. raise ValueError("The size of the padding list or tuple should be 2 or 4.")
  88. else:
  89. raise TypeError("Padding can be any of: a number, a tuple or list of size 2 or 4.")
  90. if not (isinstance(left, int) and isinstance(top, int) and isinstance(right, int) and isinstance(bottom, int)):
  91. raise TypeError("Padding value should be integer.")
  92. if left < 0 or top < 0 or right < 0 or bottom < 0:
  93. raise ValueError("Padding value could not be negative.")
  94. return left, top, right, bottom
  95. def check_degrees(degrees):
  96. """Check if the degrees is legal."""
  97. if isinstance(degrees, numbers.Number):
  98. if degrees < 0:
  99. raise ValueError("If degrees is a single number, it cannot be negative.")
  100. degrees = (-degrees, degrees)
  101. elif isinstance(degrees, (list, tuple)):
  102. if len(degrees) != 2:
  103. raise TypeError("If degrees is a sequence, the length must be 2.")
  104. else:
  105. raise TypeError("Degrees must be a single non-negative number or a sequence")
  106. return degrees
  107. def check_random_color_adjust_param(value, input_name, center=1, bound=(0, FLOAT_MAX_INTEGER), non_negative=True):
  108. """Check the parameters in random color adjust operation."""
  109. if isinstance(value, numbers.Number):
  110. if value < 0:
  111. raise ValueError("The input value of {} cannot be negative.".format(input_name))
  112. # convert value into a range
  113. value = [center - value, center + value]
  114. if non_negative:
  115. value[0] = max(0, value[0])
  116. elif isinstance(value, (list, tuple)) and len(value) == 2:
  117. if not bound[0] <= value[0] <= value[1] <= bound[1]:
  118. raise ValueError("Please check your value range of {} is valid and "
  119. "within the bound {}".format(input_name, bound))
  120. else:
  121. raise TypeError("Input of {} should be either a single value, or a list/tuple of "
  122. "length 2.".format(input_name))
  123. factor = (value[0], value[1])
  124. return factor
  125. def check_erasing_value(value):
  126. if not (isinstance(value, (numbers.Number, str, bytes)) or
  127. (isinstance(value, (tuple, list)) and len(value) == 3)):
  128. raise ValueError("The value for erasing should be either a single value, "
  129. "or a string 'random', or a sequence of 3 elements for RGB respectively.")
  130. def check_crop(method):
  131. """A wrapper that wrap a parameter checker to the original function(crop operation)."""
  132. @wraps(method)
  133. def new_method(self, *args, **kwargs):
  134. size = (list(args) + [None])[0]
  135. if "size" in kwargs:
  136. size = kwargs.get("size")
  137. if size is None:
  138. raise ValueError("size is not provided.")
  139. size = check_crop_size(size)
  140. kwargs["size"] = size
  141. return method(self, **kwargs)
  142. return new_method
  143. def check_resize_interpolation(method):
  144. """A wrapper that wrap a parameter checker to the original function(resize interpolation operation)."""
  145. @wraps(method)
  146. def new_method(self, *args, **kwargs):
  147. args = (list(args) + 2 * [None])[:2]
  148. size, interpolation = args
  149. if "size" in kwargs:
  150. size = kwargs.get("size")
  151. if "interpolation" in kwargs:
  152. interpolation = kwargs.get("interpolation")
  153. if size is None:
  154. raise ValueError("size is not provided.")
  155. size = check_resize_size(size)
  156. kwargs["size"] = size
  157. if interpolation is not None:
  158. check_inter_mode(interpolation)
  159. kwargs["interpolation"] = interpolation
  160. return method(self, **kwargs)
  161. return new_method
  162. def check_resize(method):
  163. """A wrapper that wrap a parameter checker to the original function(resize operation)."""
  164. @wraps(method)
  165. def new_method(self, *args, **kwargs):
  166. size = (list(args) + [None])[0]
  167. if "size" in kwargs:
  168. size = kwargs.get("size")
  169. if size is None:
  170. raise ValueError("size is not provided.")
  171. size = check_resize_size(size)
  172. kwargs["size"] = size
  173. return method(self, **kwargs)
  174. return new_method
  175. def check_random_resize_crop(method):
  176. """A wrapper that wrap a parameter checker to the original function(random resize crop operation)."""
  177. @wraps(method)
  178. def new_method(self, *args, **kwargs):
  179. args = (list(args) + 5 * [None])[:5]
  180. size, scale, ratio, interpolation, max_attempts = args
  181. if "size" in kwargs:
  182. size = kwargs.get("size")
  183. if "scale" in kwargs:
  184. scale = kwargs.get("scale")
  185. if "ratio" in kwargs:
  186. ratio = kwargs.get("ratio")
  187. if "interpolation" in kwargs:
  188. interpolation = kwargs.get("interpolation")
  189. if "max_attempts" in kwargs:
  190. max_attempts = kwargs.get("max_attempts")
  191. if size is None:
  192. raise ValueError("size is not provided.")
  193. size = check_crop_size(size)
  194. kwargs["size"] = size
  195. if scale is not None:
  196. check_range(scale, [0, FLOAT_MAX_INTEGER])
  197. kwargs["scale"] = scale
  198. if ratio is not None:
  199. check_range(ratio, [0, FLOAT_MAX_INTEGER])
  200. check_positive(ratio[0])
  201. kwargs["ratio"] = ratio
  202. if interpolation is not None:
  203. check_inter_mode(interpolation)
  204. kwargs["interpolation"] = interpolation
  205. if max_attempts is not None:
  206. check_pos_int32(max_attempts)
  207. kwargs["max_attempts"] = max_attempts
  208. return method(self, **kwargs)
  209. return new_method
  210. def check_prob(method):
  211. """A wrapper that wrap a parameter checker(check the probability) to the original function."""
  212. @wraps(method)
  213. def new_method(self, *args, **kwargs):
  214. prob = (list(args) + [None])[0]
  215. if "prob" in kwargs:
  216. prob = kwargs.get("prob")
  217. if prob is not None:
  218. check_value(prob, [0., 1.])
  219. kwargs["prob"] = prob
  220. return method(self, **kwargs)
  221. return new_method
  222. def check_normalize_c(method):
  223. """A wrapper that wrap a parameter checker to the original function(normalize operation written in C++)."""
  224. @wraps(method)
  225. def new_method(self, *args, **kwargs):
  226. args = (list(args) + 2 * [None])[:2]
  227. mean, std = args
  228. if "mean" in kwargs:
  229. mean = kwargs.get("mean")
  230. if "std" in kwargs:
  231. std = kwargs.get("std")
  232. if mean is None:
  233. raise ValueError("mean is not provided.")
  234. if std is None:
  235. raise ValueError("std is not provided.")
  236. check_normalize_c_param(mean, std)
  237. kwargs["mean"] = mean
  238. kwargs["std"] = std
  239. return method(self, **kwargs)
  240. return new_method
  241. def check_normalize_py(method):
  242. """A wrapper that wrap a parameter checker to the original function(normalize operation written in Python)."""
  243. @wraps(method)
  244. def new_method(self, *args, **kwargs):
  245. args = (list(args) + 2 * [None])[:2]
  246. mean, std = args
  247. if "mean" in kwargs:
  248. mean = kwargs.get("mean")
  249. if "std" in kwargs:
  250. std = kwargs.get("std")
  251. if mean is None:
  252. raise ValueError("mean is not provided.")
  253. if std is None:
  254. raise ValueError("std is not provided.")
  255. check_normalize_py_param(mean, std)
  256. kwargs["mean"] = mean
  257. kwargs["std"] = std
  258. return method(self, **kwargs)
  259. return new_method
  260. def check_random_crop(method):
  261. """Wrapper method to check the parameters of random crop."""
  262. @wraps(method)
  263. def new_method(self, *args, **kwargs):
  264. args = (list(args) + 5 * [None])[:5]
  265. size, padding, pad_if_needed, fill_value, padding_mode = args
  266. if "size" in kwargs:
  267. size = kwargs.get("size")
  268. if "padding" in kwargs:
  269. padding = kwargs.get("padding")
  270. if "fill_value" in kwargs:
  271. fill_value = kwargs.get("fill_value")
  272. if "padding_mode" in kwargs:
  273. padding_mode = kwargs.get("padding_mode")
  274. if "pad_if_needed" in kwargs:
  275. pad_if_needed = kwargs.get("pad_if_needed")
  276. if size is None:
  277. raise ValueError("size is not provided.")
  278. size = check_crop_size(size)
  279. kwargs["size"] = size
  280. if padding is not None:
  281. padding = check_padding(padding)
  282. kwargs["padding"] = padding
  283. if fill_value is not None:
  284. fill_value = check_fill_value(fill_value)
  285. kwargs["fill_value"] = fill_value
  286. if padding_mode is not None:
  287. check_border_type(padding_mode)
  288. kwargs["padding_mode"] = padding_mode
  289. if pad_if_needed is not None:
  290. kwargs["pad_if_needed"] = pad_if_needed
  291. return method(self, **kwargs)
  292. return new_method
  293. def check_random_color_adjust(method):
  294. """Wrapper method to check the parameters of random color adjust."""
  295. @wraps(method)
  296. def new_method(self, *args, **kwargs):
  297. args = (list(args) + 4 * [None])[:4]
  298. brightness, contrast, saturation, hue = args
  299. if "brightness" in kwargs:
  300. brightness = kwargs.get("brightness")
  301. if "contrast" in kwargs:
  302. contrast = kwargs.get("contrast")
  303. if "saturation" in kwargs:
  304. saturation = kwargs.get("saturation")
  305. if "hue" in kwargs:
  306. hue = kwargs.get("hue")
  307. if brightness is not None:
  308. kwargs["brightness"] = check_random_color_adjust_param(brightness, "brightness")
  309. if contrast is not None:
  310. kwargs["contrast"] = check_random_color_adjust_param(contrast, "contrast")
  311. if saturation is not None:
  312. kwargs["saturation"] = check_random_color_adjust_param(saturation, "saturation")
  313. if hue is not None:
  314. kwargs["hue"] = check_random_color_adjust_param(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)
  315. return method(self, **kwargs)
  316. return new_method
  317. def check_random_rotation(method):
  318. """Wrapper method to check the parameters of random rotation."""
  319. @wraps(method)
  320. def new_method(self, *args, **kwargs):
  321. args = (list(args) + 5 * [None])[:5]
  322. degrees, resample, expand, center, fill_value = args
  323. if "degrees" in kwargs:
  324. degrees = kwargs.get("degrees")
  325. if "resample" in kwargs:
  326. resample = kwargs.get("resample")
  327. if "expand" in kwargs:
  328. expand = kwargs.get("expand")
  329. if "center" in kwargs:
  330. center = kwargs.get("center")
  331. if "fill_value" in kwargs:
  332. fill_value = kwargs.get("fill_value")
  333. if degrees is None:
  334. raise ValueError("degrees is not provided.")
  335. degrees = check_degrees(degrees)
  336. kwargs["degrees"] = degrees
  337. if resample is not None:
  338. check_inter_mode(resample)
  339. kwargs["resample"] = resample
  340. if expand is not None:
  341. check_bool(expand)
  342. kwargs["expand"] = expand
  343. if center is not None:
  344. check_2tuple(center)
  345. kwargs["center"] = center
  346. if fill_value is not None:
  347. fill_value = check_fill_value(fill_value)
  348. kwargs["fill_value"] = fill_value
  349. return method(self, **kwargs)
  350. return new_method
  351. def check_transforms_list(method):
  352. """Wrapper method to check the parameters of transform list."""
  353. @wraps(method)
  354. def new_method(self, *args, **kwargs):
  355. transforms = (list(args) + [None])[0]
  356. if "transforms" in kwargs:
  357. transforms = kwargs.get("transforms")
  358. if transforms is None:
  359. raise ValueError("transforms is not provided.")
  360. check_list(transforms)
  361. kwargs["transforms"] = transforms
  362. return method(self, **kwargs)
  363. return new_method
  364. def check_random_apply(method):
  365. """Wrapper method to check the parameters of random apply."""
  366. @wraps(method)
  367. def new_method(self, *args, **kwargs):
  368. transforms, prob = (list(args) + 2 * [None])[:2]
  369. if "transforms" in kwargs:
  370. transforms = kwargs.get("transforms")
  371. if transforms is None:
  372. raise ValueError("transforms is not provided.")
  373. check_list(transforms)
  374. kwargs["transforms"] = transforms
  375. if "prob" in kwargs:
  376. prob = kwargs.get("prob")
  377. if prob is not None:
  378. check_value(prob, [0., 1.])
  379. kwargs["prob"] = prob
  380. return method(self, **kwargs)
  381. return new_method
  382. def check_ten_crop(method):
  383. """Wrapper method to check the parameters of crop."""
  384. @wraps(method)
  385. def new_method(self, *args, **kwargs):
  386. args = (list(args) + 2 * [None])[:2]
  387. size, use_vertical_flip = args
  388. if "size" in kwargs:
  389. size = kwargs.get("size")
  390. if "use_vertical_flip" in kwargs:
  391. use_vertical_flip = kwargs.get("use_vertical_flip")
  392. if size is None:
  393. raise ValueError("size is not provided.")
  394. size = check_crop_size(size)
  395. kwargs["size"] = size
  396. if use_vertical_flip is not None:
  397. check_bool(use_vertical_flip)
  398. kwargs["use_vertical_flip"] = use_vertical_flip
  399. return method(self, **kwargs)
  400. return new_method
  401. def check_num_channels(method):
  402. """Wrapper method to check the parameters of number of channels."""
  403. @wraps(method)
  404. def new_method(self, *args, **kwargs):
  405. num_output_channels = (list(args) + [None])[0]
  406. if "num_output_channels" in kwargs:
  407. num_output_channels = kwargs.get("num_output_channels")
  408. if num_output_channels is not None:
  409. if num_output_channels not in (1, 3):
  410. raise ValueError("Number of channels of the output grayscale image"
  411. "should be either 1 or 3. Got {0}".format(num_output_channels))
  412. kwargs["num_output_channels"] = num_output_channels
  413. return method(self, **kwargs)
  414. return new_method
  415. def check_pad(method):
  416. """Wrapper method to check the parameters of random pad."""
  417. @wraps(method)
  418. def new_method(self, *args, **kwargs):
  419. args = (list(args) + 3 * [None])[:3]
  420. padding, fill_value, padding_mode = args
  421. if "padding" in kwargs:
  422. padding = kwargs.get("padding")
  423. if "fill_value" in kwargs:
  424. fill_value = kwargs.get("fill_value")
  425. if "padding_mode" in kwargs:
  426. padding_mode = kwargs.get("padding_mode")
  427. if padding is None:
  428. raise ValueError("padding is not provided.")
  429. padding = check_padding(padding)
  430. kwargs["padding"] = padding
  431. if fill_value is not None:
  432. fill_value = check_fill_value(fill_value)
  433. kwargs["fill_value"] = fill_value
  434. if padding_mode is not None:
  435. check_border_type(padding_mode)
  436. kwargs["padding_mode"] = padding_mode
  437. return method(self, **kwargs)
  438. return new_method
  439. def check_random_perspective(method):
  440. """Wrapper method to check the parameters of random perspective."""
  441. @wraps(method)
  442. def new_method(self, *args, **kwargs):
  443. args = (list(args) + 3 * [None])[:3]
  444. distortion_scale, prob, interpolation = args
  445. if "distortion_scale" in kwargs:
  446. distortion_scale = kwargs.get("distortion_scale")
  447. if "prob" in kwargs:
  448. prob = kwargs.get("prob")
  449. if "interpolation" in kwargs:
  450. interpolation = kwargs.get("interpolation")
  451. if distortion_scale is not None:
  452. check_value(distortion_scale, [0., 1.])
  453. kwargs["distortion_scale"] = distortion_scale
  454. if prob is not None:
  455. check_value(prob, [0., 1.])
  456. kwargs["prob"] = prob
  457. if interpolation is not None:
  458. check_inter_mode(interpolation)
  459. kwargs["interpolation"] = interpolation
  460. return method(self, **kwargs)
  461. return new_method
  462. def check_mix_up(method):
  463. """Wrapper method to check the parameters of mix up."""
  464. @wraps(method)
  465. def new_method(self, *args, **kwargs):
  466. args = (list(args) + 3 * [None])[:3]
  467. batch_size, alpha, is_single = args
  468. if "batch_size" in kwargs:
  469. batch_size = kwargs.get("batch_size")
  470. if "alpha" in kwargs:
  471. alpha = kwargs.get("alpha")
  472. if "is_single" in kwargs:
  473. is_single = kwargs.get("is_single")
  474. if batch_size is None:
  475. raise ValueError("batch_size")
  476. check_pos_int32(batch_size)
  477. kwargs["batch_size"] = batch_size
  478. if alpha is None:
  479. raise ValueError("alpha")
  480. check_positive(alpha)
  481. kwargs["alpha"] = alpha
  482. if is_single is not None:
  483. check_type(is_single, bool)
  484. kwargs["is_single"] = is_single
  485. return method(self, **kwargs)
  486. return new_method
  487. def check_random_erasing(method):
  488. """Wrapper method to check the parameters of random erasing."""
  489. @wraps(method)
  490. def new_method(self, *args, **kwargs):
  491. args = (list(args) + 6 * [None])[:6]
  492. prob, scale, ratio, value, inplace, max_attempts = args
  493. if "prob" in kwargs:
  494. prob = kwargs.get("prob")
  495. if "scale" in kwargs:
  496. scale = kwargs.get("scale")
  497. if "ratio" in kwargs:
  498. ratio = kwargs.get("ratio")
  499. if "value" in kwargs:
  500. value = kwargs.get("value")
  501. if "inplace" in kwargs:
  502. inplace = kwargs.get("inplace")
  503. if "max_attempts" in kwargs:
  504. max_attempts = kwargs.get("max_attempts")
  505. if prob is not None:
  506. check_value(prob, [0., 1.])
  507. kwargs["prob"] = prob
  508. if scale is not None:
  509. check_range(scale, [0, FLOAT_MAX_INTEGER])
  510. kwargs["scale"] = scale
  511. if ratio is not None:
  512. check_range(ratio, [0, FLOAT_MAX_INTEGER])
  513. kwargs["ratio"] = ratio
  514. if value is not None:
  515. check_erasing_value(value)
  516. kwargs["value"] = value
  517. if inplace is not None:
  518. check_bool(inplace)
  519. kwargs["inplace"] = inplace
  520. if max_attempts is not None:
  521. check_pos_int32(max_attempts)
  522. kwargs["max_attempts"] = max_attempts
  523. return method(self, **kwargs)
  524. return new_method
  525. def check_cutout(method):
  526. """Wrapper method to check the parameters of cutout operation."""
  527. @wraps(method)
  528. def new_method(self, *args, **kwargs):
  529. args = (list(args) + 2 * [None])[:2]
  530. length, num_patches = args
  531. if "length" in kwargs:
  532. length = kwargs.get("length")
  533. if "num_patches" in kwargs:
  534. num_patches = kwargs.get("num_patches")
  535. if length is None:
  536. raise ValueError("length")
  537. check_pos_int32(length)
  538. kwargs["length"] = length
  539. if num_patches is not None:
  540. check_pos_int32(num_patches)
  541. kwargs["num_patches"] = num_patches
  542. return method(self, **kwargs)
  543. return new_method
  544. def check_linear_transform(method):
  545. """Wrapper method to check the parameters of linear transform."""
  546. @wraps(method)
  547. def new_method(self, *args, **kwargs):
  548. args = (list(args) + 2 * [None])[:2]
  549. transformation_matrix, mean_vector = args
  550. if "transformation_matrix" in kwargs:
  551. transformation_matrix = kwargs.get("transformation_matrix")
  552. if "mean_vector" in kwargs:
  553. mean_vector = kwargs.get("mean_vector")
  554. if transformation_matrix is None:
  555. raise ValueError("transformation_matrix is not provided.")
  556. if mean_vector is None:
  557. raise ValueError("mean_vector is not provided.")
  558. if transformation_matrix.shape[0] != transformation_matrix.shape[1]:
  559. raise ValueError("transformation_matrix should be a square matrix. "
  560. "Got shape {} instead".format(transformation_matrix.shape))
  561. if mean_vector.shape[0] != transformation_matrix.shape[0]:
  562. raise ValueError("mean_vector length {0} should match either one dimension of the square"
  563. "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
  564. kwargs["transformation_matrix"] = transformation_matrix
  565. kwargs["mean_vector"] = mean_vector
  566. return method(self, **kwargs)
  567. return new_method
  568. def check_random_affine(method):
  569. """Wrapper method to check the parameters of random affine."""
  570. @wraps(method)
  571. def new_method(self, *args, **kwargs):
  572. args = (list(args) + 6 * [None])[:6]
  573. degrees, translate, scale, shear, resample, fill_value = args
  574. if "degrees" in kwargs:
  575. degrees = kwargs.get("degrees")
  576. if "translate" in kwargs:
  577. translate = kwargs.get("translate")
  578. if "scale" in kwargs:
  579. scale = kwargs.get("scale")
  580. if "shear" in kwargs:
  581. shear = kwargs.get("shear")
  582. if "resample" in kwargs:
  583. resample = kwargs.get("resample")
  584. if "fill_value" in kwargs:
  585. fill_value = kwargs.get("fill_value")
  586. if degrees is None:
  587. raise ValueError("degrees is not provided.")
  588. degrees = check_degrees(degrees)
  589. kwargs["degrees"] = degrees
  590. if translate is not None:
  591. if isinstance(translate, (tuple, list)) and len(translate) == 2:
  592. for t in translate:
  593. if t < 0.0 or t > 1.0:
  594. raise ValueError("translation values should be between 0 and 1")
  595. else:
  596. raise TypeError("translate should be a list or tuple of length 2.")
  597. kwargs["translate"] = translate
  598. if scale is not None:
  599. if isinstance(scale, (tuple, list)) and len(scale) == 2:
  600. for s in scale:
  601. if s <= 0:
  602. raise ValueError("scale values should be positive")
  603. else:
  604. raise TypeError("scale should be a list or tuple of length 2.")
  605. kwargs["scale"] = scale
  606. if shear is not None:
  607. if isinstance(shear, numbers.Number):
  608. if shear < 0:
  609. raise ValueError("If shear is a single number, it must be positive.")
  610. shear = (-1 * shear, shear)
  611. elif isinstance(shear, (tuple, list)) and (len(shear) == 2 or len(shear) == 4):
  612. # X-Axis shear with [min, max]
  613. if len(shear) == 2:
  614. shear = [shear[0], shear[1], 0., 0.]
  615. elif len(shear) == 4:
  616. shear = [s for s in shear]
  617. else:
  618. raise TypeError("shear should be a list or tuple and it must be of length 2 or 4.")
  619. kwargs["shear"] = shear
  620. if resample is not None:
  621. check_inter_mode(resample)
  622. kwargs["resample"] = resample
  623. if fill_value is not None:
  624. fill_value = check_fill_value(fill_value)
  625. kwargs["fill_value"] = fill_value
  626. return method(self, **kwargs)
  627. return new_method
  628. def check_rescale(method):
  629. """Wrapper method to check the parameters of rescale."""
  630. @wraps(method)
  631. def new_method(self, *args, **kwargs):
  632. rescale, shift = (list(args) + 2 * [None])[:2]
  633. if "rescale" in kwargs:
  634. rescale = kwargs.get("rescale")
  635. if "shift" in kwargs:
  636. shift = kwargs.get("shift")
  637. if rescale is None:
  638. raise ValueError("rescale is not provided.")
  639. check_pos_float32(rescale)
  640. kwargs["rescale"] = rescale
  641. if shift is None:
  642. raise ValueError("shift is not provided.")
  643. if not isinstance(shift, numbers.Number):
  644. raise TypeError("shift is not a number.")
  645. kwargs["shift"] = shift
  646. return method(self, **kwargs)
  647. return new_method
  648. def check_uniform_augment_cpp(method):
  649. """Wrapper method to check the parameters of UniformAugment cpp op."""
  650. @wraps(method)
  651. def new_method(self, *args, **kwargs):
  652. operations, num_ops = (list(args) + 2 * [None])[:2]
  653. if "operations" in kwargs:
  654. operations = kwargs.get("operations")
  655. else:
  656. raise ValueError("operations list required")
  657. if "num_ops" in kwargs:
  658. num_ops = kwargs.get("num_ops")
  659. else:
  660. num_ops = 2
  661. if not isinstance(num_ops, int):
  662. raise ValueError("Number of operations should be an integer.")
  663. if num_ops <= 0:
  664. raise ValueError("num_ops should be greater than zero")
  665. if num_ops > len(operations):
  666. raise ValueError("num_ops is greater than operations list size")
  667. if not isinstance(operations, list):
  668. raise TypeError("operations is not a python list")
  669. for op in operations:
  670. if not isinstance(op, TensorOp):
  671. raise ValueError("operations list only accepts C++ operations.")
  672. kwargs["num_ops"] = num_ops
  673. kwargs["operations"] = operations
  674. return method(self, **kwargs)
  675. return new_method
  676. def check_bounding_box_augment_cpp(method):
  677. """Wrapper method to check the parameters of BoundingBoxAugment cpp op."""
  678. @wraps(method)
  679. def new_method(self, *args, **kwargs):
  680. transform, ratio = (list(args) + 2 * [None])[:2]
  681. if "transform" in kwargs:
  682. transform = kwargs.get("transform")
  683. if "ratio" in kwargs:
  684. ratio = kwargs.get("ratio")
  685. if not isinstance(ratio, float) and not isinstance(ratio, int):
  686. raise ValueError("Ratio should be an int or float.")
  687. if ratio is not None:
  688. check_value(ratio, [0., 1.])
  689. kwargs["ratio"] = ratio
  690. else:
  691. ratio = 0.3
  692. if not isinstance(transform, TensorOp):
  693. raise ValueError("Transform can only be a C++ operation.")
  694. kwargs["transform"] = transform
  695. kwargs["ratio"] = ratio
  696. return method(self, **kwargs)
  697. return new_method
  698. def check_uniform_augment_py(method):
  699. """Wrapper method to check the parameters of python UniformAugment op."""
  700. @wraps(method)
  701. def new_method(self, *args, **kwargs):
  702. transforms, num_ops = (list(args) + 2 * [None])[:2]
  703. if "transforms" in kwargs:
  704. transforms = kwargs.get("transforms")
  705. if transforms is None:
  706. raise ValueError("transforms is not provided.")
  707. if not transforms:
  708. raise ValueError("transforms list is empty.")
  709. check_list(transforms)
  710. for transform in transforms:
  711. if isinstance(transform, TensorOp):
  712. raise ValueError("transform list only accepts Python operations.")
  713. kwargs["transforms"] = transforms
  714. if "num_ops" in kwargs:
  715. num_ops = kwargs.get("num_ops")
  716. if num_ops is not None:
  717. check_type(num_ops, int)
  718. check_positive(num_ops)
  719. if num_ops > len(transforms):
  720. raise ValueError("num_ops cannot be greater than the length of transforms list.")
  721. kwargs["num_ops"] = num_ops
  722. return method(self, **kwargs)
  723. return new_method
  724. def check_positive_degrees(method):
  725. """A wrapper method to check degrees parameter in RandSharpness and RandColor"""
  726. @wraps(method)
  727. def new_method(self, *args, **kwargs):
  728. degrees = (list(args) + [None])[0]
  729. if "degrees" in kwargs:
  730. degrees = kwargs.get("degrees")
  731. if degrees is not None:
  732. if isinstance(degrees, (list, tuple)):
  733. if len(degrees) != 2:
  734. raise ValueError("Degrees must be a sequence with length 2.")
  735. if degrees[0] < 0:
  736. raise ValueError("Degrees range must be non-negative.")
  737. if degrees[0] > degrees[1]:
  738. raise ValueError("Degrees should be in (min,max) format. Got (max,min).")
  739. else:
  740. raise TypeError("Degrees must be a sequence in (min,max) format.")
  741. return method(self, **kwargs)
  742. return new_method
  743. def check_compose_list(method):
  744. """Wrapper method to check the transform list of ComposeOp."""
  745. @wraps(method)
  746. def new_method(self, *args, **kwargs):
  747. transforms = (list(args) + [None])[0]
  748. if "transforms" in kwargs:
  749. transforms = kwargs.get("transforms")
  750. if transforms is None:
  751. raise ValueError("transforms is not provided.")
  752. if not transforms:
  753. raise ValueError("transforms list is empty.")
  754. if not isinstance(transforms, list):
  755. raise TypeError("transforms is not a python list")
  756. kwargs["transforms"] = transforms
  757. return method(self, **kwargs)
  758. return new_method