You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

py_transforms_util.py 56 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Built-in py_transforms_utils functions.
  16. """
  17. import io
  18. import math
  19. import numbers
  20. import random
  21. import colorsys
  22. import numpy as np
  23. from PIL import Image, ImageOps, ImageEnhance, __version__
  24. from .utils import Inter
  25. from ..core.py_util_helpers import is_numpy
  26. augment_error_message = "img should be PIL image. Got {}. Use Decode() for encoded data or ToPIL() for decoded data."
  27. def is_pil(img):
  28. """
  29. Check if the input image is PIL format.
  30. Args:
  31. img: Image to be checked.
  32. Returns:
  33. Bool, True if input is PIL image.
  34. """
  35. return isinstance(img, Image.Image)
  36. def normalize(img, mean, std, pad_channel=False, dtype="float32"):
  37. """
  38. Normalize the image between [0, 1] with respect to mean and standard deviation.
  39. Args:
  40. img (numpy.ndarray): Image array of shape CHW to be normalized.
  41. mean (list): List of mean values for each channel, w.r.t channel order.
  42. std (list): List of standard deviations for each channel, w.r.t. channel order.
  43. pad_channel (bool): Whether to pad a extra channel with value zero.
  44. dtype (str): Output datatype of normalize, only worked when pad_channel is True. (default is "float32")
  45. Returns:
  46. img (numpy.ndarray), Normalized image.
  47. """
  48. if np.issubdtype(img.dtype, np.integer):
  49. raise NotImplementedError("Unsupported image datatype: [{}], pls execute [ToTensor] before [Normalize]."
  50. .format(img.dtype))
  51. if not is_numpy(img):
  52. raise TypeError("img should be NumPy image. Got {}.".format(type(img)))
  53. num_channels = img.shape[0] # shape is (C, H, W)
  54. if len(mean) != len(std):
  55. raise ValueError("Length of mean and std must be equal.")
  56. # if length equal to 1, adjust the mean and std arrays to have the correct
  57. # number of channels (replicate the values)
  58. if len(mean) == 1:
  59. mean = [mean[0]] * num_channels
  60. std = [std[0]] * num_channels
  61. elif len(mean) != num_channels:
  62. raise ValueError("Length of mean and std must both be 1 or equal to the number of channels({0})."
  63. .format(num_channels))
  64. mean = np.array(mean, dtype=img.dtype)
  65. std = np.array(std, dtype=img.dtype)
  66. image = (img - mean[:, None, None]) / std[:, None, None]
  67. if pad_channel:
  68. zeros = np.zeros([1, image.shape[1], image.shape[2]], dtype=np.float32)
  69. image = np.concatenate((image, zeros), axis=0)
  70. if dtype == "float16":
  71. image = image.astype(np.float16)
  72. return image
  73. def decode(img):
  74. """
  75. Decode the input image to PIL image format in RGB mode.
  76. Args:
  77. img: Image to be decoded.
  78. Returns:
  79. img (PIL image), Decoded image in RGB mode.
  80. """
  81. try:
  82. data = io.BytesIO(img)
  83. img = Image.open(data)
  84. return img.convert('RGB')
  85. except IOError as e:
  86. raise ValueError("{0}\nWARNING: Failed to decode given image.".format(e))
  87. except AttributeError as e:
  88. raise ValueError("{0}\nWARNING: Failed to decode, Image might already be decoded.".format(e))
  89. def hwc_to_chw(img):
  90. """
  91. Transpose the input image; shape (H, W, C) to shape (C, H, W).
  92. Args:
  93. img (numpy.ndarray): Image to be converted.
  94. Returns:
  95. img (numpy.ndarray), Converted image.
  96. """
  97. if is_numpy(img):
  98. return img.transpose(2, 0, 1).copy()
  99. raise TypeError('img should be NumPy array. Got {}.'.format(type(img)))
  100. def to_tensor(img, output_type):
  101. """
  102. Change the input image (PIL image or NumPy image array) to NumPy format.
  103. Args:
  104. img (Union[PIL image, numpy.ndarray]): Image to be converted.
  105. output_type: The datatype of the NumPy output. e.g. np.float32
  106. Returns:
  107. img (numpy.ndarray), Converted image.
  108. """
  109. if not (is_pil(img) or is_numpy(img)):
  110. raise TypeError("img should be PIL image or NumPy array. Got {}.".format(type(img)))
  111. img = np.asarray(img)
  112. if img.ndim not in (2, 3):
  113. raise ValueError("img dimension should be 2 or 3. Got {}.".format(img.ndim))
  114. if img.ndim == 2:
  115. img = img[:, :, None]
  116. img = hwc_to_chw(img)
  117. img = img / 255.
  118. return to_type(img, output_type)
  119. def to_pil(img):
  120. """
  121. Convert the input image to PIL format.
  122. Args:
  123. img: Image to be converted.
  124. Returns:
  125. img (PIL image), Converted image.
  126. """
  127. if not is_pil(img):
  128. return Image.fromarray(img)
  129. return img
  130. def horizontal_flip(img):
  131. """
  132. Flip the input image horizontally.
  133. Args:
  134. img (PIL image): Image to be flipped horizontally.
  135. Returns:
  136. img (PIL image), Horizontally flipped image.
  137. """
  138. if not is_pil(img):
  139. raise TypeError(augment_error_message.format(type(img)))
  140. return img.transpose(Image.FLIP_LEFT_RIGHT)
  141. def vertical_flip(img):
  142. """
  143. Flip the input image vertically.
  144. Args:
  145. img (PIL image): Image to be flipped vertically.
  146. Returns:
  147. img (PIL image), Vertically flipped image.
  148. """
  149. if not is_pil(img):
  150. raise TypeError(augment_error_message.format(type(img)))
  151. return img.transpose(Image.FLIP_TOP_BOTTOM)
  152. def random_horizontal_flip(img, prob):
  153. """
  154. Randomly flip the input image horizontally.
  155. Args:
  156. img (PIL image): Image to be flipped.
  157. If the given probability is above the random probability, then the image is flipped.
  158. prob (float): Probability of the image being flipped.
  159. Returns:
  160. img (PIL image), Converted image.
  161. """
  162. if not is_pil(img):
  163. raise TypeError(augment_error_message.format(type(img)))
  164. if prob > random.random():
  165. img = horizontal_flip(img)
  166. return img
  167. def random_vertical_flip(img, prob):
  168. """
  169. Randomly flip the input image vertically.
  170. Args:
  171. img (PIL image): Image to be flipped.
  172. If the given probability is above the random probability, then the image is flipped.
  173. prob (float): Probability of the image being flipped.
  174. Returns:
  175. img (PIL image), Converted image.
  176. """
  177. if not is_pil(img):
  178. raise TypeError(augment_error_message.format(type(img)))
  179. if prob > random.random():
  180. img = vertical_flip(img)
  181. return img
  182. def crop(img, top, left, height, width):
  183. """
  184. Crop the input PIL image.
  185. Args:
  186. img (PIL image): Image to be cropped. (0,0) denotes the top left corner of the image,
  187. in the directions of (width, height).
  188. top (int): Vertical component of the top left corner of the crop box.
  189. left (int): Horizontal component of the top left corner of the crop box.
  190. height (int): Height of the crop box.
  191. width (int): Width of the crop box.
  192. Returns:
  193. img (PIL image), Cropped image.
  194. """
  195. if not is_pil(img):
  196. raise TypeError(augment_error_message.format(type(img)))
  197. return img.crop((left, top, left + width, top + height))
  198. def resize(img, size, interpolation=Inter.BILINEAR):
  199. """
  200. Resize the input PIL image to desired size.
  201. Args:
  202. img (PIL image): Image to be resized.
  203. size (Union[int, sequence]): The output size of the resized image.
  204. If size is an integer, smaller edge of the image will be resized to this value with
  205. the same image aspect ratio.
  206. If size is a sequence of (height, width), this will be the desired output size.
  207. interpolation (interpolation mode): Image interpolation mode. Default is Inter.BILINEAR = 2.
  208. Returns:
  209. img (PIL image), Resized image.
  210. """
  211. if not is_pil(img):
  212. raise TypeError(augment_error_message.format(type(img)))
  213. if not (isinstance(size, int) or (isinstance(size, (list, tuple)) and len(size) == 2)):
  214. raise TypeError('Size should be a single number or a list/tuple (h, w) of length 2.'
  215. 'Got {}.'.format(size))
  216. if isinstance(size, int):
  217. img_width, img_height = img.size
  218. aspect_ratio = img_width / img_height # maintain the aspect ratio
  219. if (img_width <= img_height and img_width == size) or \
  220. (img_height <= img_width and img_height == size):
  221. return img
  222. if img_width < img_height:
  223. out_width = size
  224. out_height = int(size / aspect_ratio)
  225. return img.resize((out_width, out_height), interpolation)
  226. out_height = size
  227. out_width = int(size * aspect_ratio)
  228. return img.resize((out_width, out_height), interpolation)
  229. return img.resize(size[::-1], interpolation)
  230. def center_crop(img, size):
  231. """
  232. Crop the input PIL image at the center to the given size.
  233. Args:
  234. img (PIL image): Image to be cropped.
  235. size (Union[int, tuple]): The size of the crop box.
  236. If size is an integer, a square crop of size (size, size) is returned.
  237. If size is a sequence of length 2, it should be (height, width).
  238. Returns:
  239. img (PIL image), Cropped image.
  240. """
  241. if not is_pil(img):
  242. raise TypeError(augment_error_message.format(type(img)))
  243. if isinstance(size, int):
  244. size = (size, size)
  245. img_width, img_height = img.size
  246. crop_height, crop_width = size
  247. crop_top = int(round((img_height - crop_height) / 2.))
  248. crop_left = int(round((img_width - crop_width) / 2.))
  249. return crop(img, crop_top, crop_left, crop_height, crop_width)
  250. def random_resize_crop(img, size, scale, ratio, interpolation=Inter.BILINEAR, max_attempts=10):
  251. """
  252. Crop the input PIL image to a random size and aspect ratio.
  253. Args:
  254. img (PIL image): Image to be randomly cropped and resized.
  255. size (Union[int, sequence]): The size of the output image.
  256. If size is an integer, a square crop of size (size, size) is returned.
  257. If size is a sequence of length 2, it should be (height, width).
  258. scale (tuple): Range (min, max) of respective size of the original size to be cropped.
  259. ratio (tuple): Range (min, max) of aspect ratio to be cropped.
  260. interpolation (interpolation mode): Image interpolation mode. Default is Inter.BILINEAR = 2.
  261. max_attempts (int): The maximum number of attempts to propose a valid crop_area. Default 10.
  262. If exceeded, fall back to use center_crop instead.
  263. Returns:
  264. img (PIL image), Randomly cropped and resized image.
  265. """
  266. if not is_pil(img):
  267. raise TypeError(augment_error_message.format(type(img)))
  268. if isinstance(size, int):
  269. size = (size, size)
  270. elif isinstance(size, (tuple, list)) and len(size) == 2:
  271. size = size
  272. else:
  273. raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
  274. if scale[0] > scale[1] or ratio[0] > ratio[1]:
  275. raise ValueError("Range should be in the order of (min, max).")
  276. def _input_to_factor(img, scale, ratio):
  277. img_width, img_height = img.size
  278. img_area = img_width * img_height
  279. for _ in range(max_attempts):
  280. crop_area = random.uniform(scale[0], scale[1]) * img_area
  281. # in case of non-symmetrical aspect ratios,
  282. # use uniform distribution on a logarithmic scale.
  283. log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
  284. aspect_ratio = math.exp(random.uniform(*log_ratio))
  285. width = int(round(math.sqrt(crop_area * aspect_ratio)))
  286. height = int(round(width / aspect_ratio))
  287. if 0 < width <= img_width and 0 < height <= img_height:
  288. top = random.randint(0, img_height - height)
  289. left = random.randint(0, img_width - width)
  290. return top, left, height, width
  291. # exceeding max_attempts, use center crop
  292. img_ratio = img_width / img_height
  293. if img_ratio < ratio[0]:
  294. width = img_width
  295. height = int(round(width / ratio[0]))
  296. elif img_ratio > ratio[1]:
  297. height = img_height
  298. width = int(round(height * ratio[1]))
  299. else:
  300. width = img_width
  301. height = img_height
  302. top = int(round((img_height - height) / 2.))
  303. left = int(round((img_width - width) / 2.))
  304. return top, left, height, width
  305. top, left, height, width = _input_to_factor(img, scale, ratio)
  306. img = crop(img, top, left, height, width)
  307. img = resize(img, size, interpolation)
  308. return img
  309. def random_crop(img, size, padding, pad_if_needed, fill_value, padding_mode):
  310. """
  311. Crop the input PIL image at a random location.
  312. Args:
  313. img (PIL image): Image to be randomly cropped.
  314. size (Union[int, sequence]): The output size of the cropped image.
  315. If size is an integer, a square crop of size (size, size) is returned.
  316. If size is a sequence of length 2, it should be (height, width).
  317. padding (Union[int, sequence], optional): The number of pixels to pad the image.
  318. If a single number is provided, it pads all borders with this value.
  319. If a tuple or list of 2 values are provided, it pads the (left and top)
  320. with the first value and (right and bottom) with the second value.
  321. If 4 values are provided as a list or tuple,
  322. it pads the left, top, right and bottom respectively.
  323. Default is None.
  324. pad_if_needed (bool): Pad the image if either side is smaller than
  325. the given output size. Default is False.
  326. fill_value (Union[int, tuple]): The pixel intensity of the borders if
  327. the padding_mode is 'constant'. If it is a 3-tuple, it is used to
  328. fill R, G, B channels respectively.
  329. padding_mode (str): The method of padding. Can be any of
  330. ['constant', 'edge', 'reflect', 'symmetric'].
  331. - 'constant', means it fills the border with constant values
  332. - 'edge', means it pads with the last value on the edge
  333. - 'reflect', means it reflects the values on the edge omitting the last
  334. value of edge
  335. - 'symmetric', means it reflects the values on the edge repeating the last
  336. value of edge
  337. Returns:
  338. PIL image, Cropped image.
  339. """
  340. if not is_pil(img):
  341. raise TypeError(augment_error_message.format(type(img)))
  342. if isinstance(size, int):
  343. size = (size, size)
  344. elif isinstance(size, (tuple, list)) and len(size) == 2:
  345. size = size
  346. else:
  347. raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
  348. def _input_to_factor(img, size):
  349. img_width, img_height = img.size
  350. height, width = size
  351. if height > img_height or width > img_width:
  352. raise ValueError("Crop size {} is larger than input image size {}.".format(size, (img_height, img_width)))
  353. if width == img_width and height == img_height:
  354. return 0, 0, img_height, img_width
  355. top = random.randint(0, img_height - height)
  356. left = random.randint(0, img_width - width)
  357. return top, left, height, width
  358. if padding is not None:
  359. img = pad(img, padding, fill_value, padding_mode)
  360. # pad width when needed, img.size (width, height), crop size (height, width)
  361. if pad_if_needed and img.size[0] < size[1]:
  362. img = pad(img, (size[1] - img.size[0], 0), fill_value, padding_mode)
  363. # pad height when needed
  364. if pad_if_needed and img.size[1] < size[0]:
  365. img = pad(img, (0, size[0] - img.size[1]), fill_value, padding_mode)
  366. top, left, height, width = _input_to_factor(img, size)
  367. return crop(img, top, left, height, width)
  368. def adjust_brightness(img, brightness_factor):
  369. """
  370. Adjust brightness of an image.
  371. Args:
  372. img (PIL image): Image to be adjusted.
  373. brightness_factor (float): A non negative number indicated the factor by which
  374. the brightness is adjusted. 0 gives a black image, 1 gives the original.
  375. Returns:
  376. img (PIL image), Brightness adjusted image.
  377. """
  378. if not is_pil(img):
  379. raise TypeError(augment_error_message.format(type(img)))
  380. enhancer = ImageEnhance.Brightness(img)
  381. img = enhancer.enhance(brightness_factor)
  382. return img
  383. def adjust_contrast(img, contrast_factor):
  384. """
  385. Adjust contrast of an image.
  386. Args:
  387. img (PIL image): PIL image to be adjusted.
  388. contrast_factor (float): A non negative number indicated the factor by which
  389. the contrast is adjusted. 0 gives a solid gray image, 1 gives the original.
  390. Returns:
  391. img (PIL image), Contrast adjusted image.
  392. """
  393. if not is_pil(img):
  394. raise TypeError(augment_error_message.format(type(img)))
  395. enhancer = ImageEnhance.Contrast(img)
  396. img = enhancer.enhance(contrast_factor)
  397. return img
  398. def adjust_saturation(img, saturation_factor):
  399. """
  400. Adjust saturation of an image.
  401. Args:
  402. img (PIL image): PIL image to be adjusted.
  403. saturation_factor (float): A non negative number indicated the factor by which
  404. the saturation is adjusted. 0 will give a black and white image, 1 will
  405. give the original.
  406. Returns:
  407. img (PIL image), Saturation adjusted image.
  408. """
  409. if not is_pil(img):
  410. raise TypeError(augment_error_message.format(type(img)))
  411. enhancer = ImageEnhance.Color(img)
  412. img = enhancer.enhance(saturation_factor)
  413. return img
def adjust_hue(img, hue_factor):
    """
    Adjust hue of an image. The Hue is changed by changing the HSV values after image is converted to HSV.

    Args:
        img (PIL image): PIL image to be adjusted.
        hue_factor (float): Amount to shift the Hue channel. Value should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel. This
            is because Hue wraps around when rotated 360 degrees.
            0 means no shift that gives the original image while both -0.5 and 0.5
            will give an image with complementary colors .

    Returns:
        img (PIL image), Hue adjusted image.

    Raises:
        ValueError: If `hue_factor` is outside [-0.5, 0.5].
        TypeError: If `img` is not a PIL image.
    """
    image = img
    image_hue_factor = hue_factor
    if not -0.5 <= image_hue_factor <= 0.5:
        raise ValueError('image_hue_factor {} is not in [-0.5, 0.5].'.format(image_hue_factor))
    if not is_pil(image):
        raise TypeError(augment_error_message.format(type(image)))
    mode = image.mode
    # Grayscale / binary / integer / float single-band modes carry no hue
    # information; return them unchanged.
    if mode in {'L', '1', 'I', 'F'}:
        return image
    hue, saturation, value = img.convert('HSV').split()
    np_hue = np.array(hue, dtype=np.uint8)
    # Hue is cyclic, so shifting relies on deliberate uint8 wraparound:
    # overflow past 255 is the desired modular arithmetic, hence over='ignore'.
    # NOTE(review): for negative hue_factor this depends on negative-float ->
    # uint8 conversion wrapping; newer NumPy versions reject that — confirm.
    with np.errstate(over='ignore'):
        np_hue += np.uint8(image_hue_factor * 255)
    hue = Image.fromarray(np_hue, 'L')
    # Reassemble HSV with the shifted hue and convert back to the input mode.
    image = Image.merge('HSV', (hue, saturation, value)).convert(mode)
    return image
  443. def to_type(img, output_type):
  444. """
  445. Convert the NumPy image array to desired NumPy dtype.
  446. Args:
  447. img (numpy): NumPy image to cast to desired NumPy dtype.
  448. output_type (Numpy datatype): NumPy dtype to cast to.
  449. Returns:
  450. img (numpy.ndarray), Converted image.
  451. """
  452. if not is_numpy(img):
  453. raise TypeError("img should be NumPy image. Got {}.".format(type(img)))
  454. return img.astype(output_type)
  455. def rotate(img, angle, resample, expand, center, fill_value):
  456. """
  457. Rotate the input PIL image by angle.
  458. Args:
  459. img (PIL image): Image to be rotated.
  460. angle (int or float): Rotation angle in degrees, counter-clockwise.
  461. resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter.
  462. If omitted, or if the image has mode "1" or "P", it is set to be Inter.NEAREST.
  463. expand (bool, optional): Optional expansion flag. If set to True, expand the output
  464. image to make it large enough to hold the entire rotated image.
  465. If set to False or omitted, make the output image the same size as the input.
  466. Note that the expand flag assumes rotation around the center and no translation.
  467. center (tuple, optional): Optional center of rotation (a 2-tuple).
  468. Origin is the top left corner.
  469. fill_value (Union[int, tuple]): Optional fill color for the area outside the rotated image.
  470. If it is a 3-tuple, it is used for R, G, B channels respectively.
  471. If it is an integer, it is used for all RGB channels.
  472. Returns:
  473. img (PIL image), Rotated image.
  474. https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate
  475. """
  476. if not is_pil(img):
  477. raise TypeError(augment_error_message.format(type(img)))
  478. if isinstance(fill_value, int):
  479. fill_value = tuple([fill_value] * 3)
  480. return img.rotate(angle, resample, expand, center, fillcolor=fill_value)
  481. def random_color_adjust(img, brightness, contrast, saturation, hue):
  482. """
  483. Randomly adjust the brightness, contrast, saturation, and hue of an image.
  484. Args:
  485. img (PIL image): Image to have its color adjusted randomly.
  486. brightness (Union[float, tuple]): Brightness adjustment factor. Cannot be negative.
  487. If it is a float, the factor is uniformly chosen from the range [max(0, 1-brightness), 1+brightness].
  488. If it is a sequence, it should be [min, max] for the range.
  489. contrast (Union[float, tuple]): Contrast adjustment factor. Cannot be negative.
  490. If it is a float, the factor is uniformly chosen from the range [max(0, 1-contrast), 1+contrast].
  491. If it is a sequence, it should be [min, max] for the range.
  492. saturation (Union[float, tuple]): Saturation adjustment factor. Cannot be negative.
  493. If it is a float, the factor is uniformly chosen from the range [max(0, 1-saturation), 1+saturation].
  494. If it is a sequence, it should be [min, max] for the range.
  495. hue (Union[float, tuple]): Hue adjustment factor.
  496. If it is a float, the range will be [-hue, hue]. Value should be 0 <= hue <= 0.5.
  497. If it is a sequence, it should be [min, max] where -0.5 <= min <= max <= 0.5.
  498. Returns:
  499. img (PIL image), Image after random adjustment of its color.
  500. """
  501. if not is_pil(img):
  502. raise TypeError(augment_error_message.format(type(img)))
  503. def _input_to_factor(value, input_name, center=1, bound=(0, float('inf')), non_negative=True):
  504. if isinstance(value, numbers.Number):
  505. if value < 0:
  506. raise ValueError("The input value of {} cannot be negative.".format(input_name))
  507. # convert value into a range
  508. value = [center - value, center + value]
  509. if non_negative:
  510. value[0] = max(0, value[0])
  511. elif isinstance(value, (list, tuple)) and len(value) == 2:
  512. if not bound[0] <= value[0] <= value[1] <= bound[1]:
  513. raise ValueError("Please check your value range of {} is valid and "
  514. "within the bound {}.".format(input_name, bound))
  515. else:
  516. raise TypeError("Input of {} should be either a single value, or a list/tuple of "
  517. "length 2.".format(input_name))
  518. factor = random.uniform(value[0], value[1])
  519. return factor
  520. brightness_factor = _input_to_factor(brightness, 'brightness')
  521. contrast_factor = _input_to_factor(contrast, 'contrast')
  522. saturation_factor = _input_to_factor(saturation, 'saturation')
  523. hue_factor = _input_to_factor(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)
  524. transforms = []
  525. transforms.append(lambda img: adjust_brightness(img, brightness_factor))
  526. transforms.append(lambda img: adjust_contrast(img, contrast_factor))
  527. transforms.append(lambda img: adjust_saturation(img, saturation_factor))
  528. transforms.append(lambda img: adjust_hue(img, hue_factor))
  529. # apply color adjustments in a random order
  530. random.shuffle(transforms)
  531. for transform in transforms:
  532. img = transform(img)
  533. return img
  534. def random_rotation(img, degrees, resample, expand, center, fill_value):
  535. """
  536. Rotate the input PIL image by a random angle.
  537. Args:
  538. img (PIL image): Image to be rotated.
  539. degrees (Union[int, float, sequence]): Range of random rotation degrees.
  540. If degrees is a number, the range will be converted to (-degrees, degrees).
  541. If degrees is a sequence, it should be (min, max).
  542. resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter.
  543. If omitted, or if the image has mode "1" or "P", it is set to be Inter.NEAREST.
  544. expand (bool, optional): Optional expansion flag. If set to True, expand the output
  545. image to make it large enough to hold the entire rotated image.
  546. If set to False or omitted, make the output image the same size as the input.
  547. Note that the expand flag assumes rotation around the center and no translation.
  548. center (tuple, optional): Optional center of rotation (a 2-tuple).
  549. Origin is the top left corner.
  550. fill_value (Union[int, tuple]): Optional fill color for the area outside the rotated image.
  551. If it is a 3-tuple, it is used for R, G, B channels respectively.
  552. If it is an integer, it is used for all RGB channels.
  553. Returns:
  554. img (PIL image), Rotated image.
  555. https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate
  556. """
  557. if not is_pil(img):
  558. raise TypeError(augment_error_message.format(type(img)))
  559. if isinstance(degrees, numbers.Number):
  560. if degrees < 0:
  561. raise ValueError("If degrees is a single number, it cannot be negative.")
  562. degrees = (-degrees, degrees)
  563. elif isinstance(degrees, (list, tuple)):
  564. if len(degrees) != 2:
  565. raise ValueError("If degrees is a sequence, the length must be 2.")
  566. else:
  567. raise TypeError("Degrees must be a single non-negative number or a sequence.")
  568. angle = random.uniform(degrees[0], degrees[1])
  569. return rotate(img, angle, resample, expand, center, fill_value)
  570. def five_crop(img, size):
  571. """
  572. Generate 5 cropped images (one central and four corners).
  573. Args:
  574. img (PIL image): PIL image to be cropped.
  575. size (Union[int, sequence]): The output size of the crop.
  576. If size is an integer, a square crop of size (size, size) is returned.
  577. If size is a sequence of length 2, it should be (height, width).
  578. Returns:
  579. img_tuple (tuple), a tuple of 5 PIL images
  580. (top_left, top_right, bottom_left, bottom_right, center).
  581. """
  582. if not is_pil(img):
  583. raise TypeError(augment_error_message.format(type(img)))
  584. if isinstance(size, int):
  585. size = (size, size)
  586. elif isinstance(size, (tuple, list)) and len(size) == 2:
  587. size = size
  588. else:
  589. raise TypeError("Size should be a single number or a list/tuple (h, w) of length 2.")
  590. # PIL image.size returns in (width, height) order
  591. img_width, img_height = img.size
  592. crop_height, crop_width = size
  593. if crop_height > img_height or crop_width > img_width:
  594. raise ValueError("Crop size {} is larger than input image size {}.".format(size, (img_height, img_width)))
  595. center = center_crop(img, (crop_height, crop_width))
  596. top_left = img.crop((0, 0, crop_width, crop_height))
  597. top_right = img.crop((img_width - crop_width, 0, img_width, crop_height))
  598. bottom_left = img.crop((0, img_height - crop_height, crop_width, img_height))
  599. bottom_right = img.crop((img_width - crop_width, img_height - crop_height, img_width, img_height))
  600. return top_left, top_right, bottom_left, bottom_right, center
  601. def ten_crop(img, size, use_vertical_flip=False):
  602. """
  603. Generate 10 cropped images (first 5 from FiveCrop, second 5 from their flipped version).
  604. The default is horizontal flipping, use_vertical_flip=False.
  605. Args:
  606. img (PIL image): PIL image to be cropped.
  607. size (Union[int, sequence]): The output size of the crop.
  608. If size is an integer, a square crop of size (size, size) is returned.
  609. If size is a sequence of length 2, it should be (height, width).
  610. use_vertical_flip (bool): Flip the image vertically instead of horizontally if set to True.
  611. Returns:
  612. img_tuple (tuple), a tuple of 10 PIL images
  613. (top_left, top_right, bottom_left, bottom_right, center) of original image +
  614. (top_left, top_right, bottom_left, bottom_right, center) of flipped image.
  615. """
  616. if not is_pil(img):
  617. raise TypeError(augment_error_message.format(type(img)))
  618. if isinstance(size, int):
  619. size = (size, size)
  620. elif isinstance(size, (tuple, list)) and len(size) == 2:
  621. size = size
  622. else:
  623. raise TypeError("Size should be a single number or a list/tuple (h, w) of length 2.")
  624. first_five_crop = five_crop(img, size)
  625. if use_vertical_flip:
  626. img = vertical_flip(img)
  627. else:
  628. img = horizontal_flip(img)
  629. second_five_crop = five_crop(img, size)
  630. return first_five_crop + second_five_crop
  631. def grayscale(img, num_output_channels):
  632. """
  633. Convert the input PIL image to grayscale image.
  634. Args:
  635. img (PIL image): PIL image to be converted to grayscale.
  636. num_output_channels (int): Number of channels of the output grayscale image (1 or 3).
  637. Returns:
  638. img (PIL image), grayscaled image.
  639. """
  640. if not is_pil(img):
  641. raise TypeError(augment_error_message.format(type(img)))
  642. if num_output_channels == 1:
  643. img = img.convert('L')
  644. elif num_output_channels == 3:
  645. # each channel is the same grayscale layer
  646. img = img.convert('L')
  647. np_gray = np.array(img, dtype=np.uint8)
  648. np_img = np.dstack([np_gray, np_gray, np_gray])
  649. img = Image.fromarray(np_img, 'RGB')
  650. else:
  651. raise ValueError('num_output_channels should be either 1 or 3. Got {}.'.format(num_output_channels))
  652. return img
  653. def pad(img, padding, fill_value, padding_mode):
  654. """
  655. Pads the image according to padding parameters.
  656. Args:
  657. img (PIL image): Image to be padded.
  658. padding (Union[int, sequence], optional): The number of pixels to pad the image.
  659. If a single number is provided, it pads all borders with this value.
  660. If a tuple or list of 2 values are provided, it pads the (left and top)
  661. with the first value and (right and bottom) with the second value.
  662. If 4 values are provided as a list or tuple,
  663. it pads the left, top, right and bottom respectively.
  664. Default is None.
  665. fill_value (Union[int, tuple]): The pixel intensity of the borders if
  666. the padding_mode is "constant". If it is a 3-tuple, it is used to
  667. fill R, G, B channels respectively.
  668. padding_mode (str): The method of padding. Can be any of
  669. ['constant', 'edge', 'reflect', 'symmetric'].
  670. - 'constant', means it fills the border with constant values
  671. - 'edge', means it pads with the last value on the edge
  672. - 'reflect', means it reflects the values on the edge omitting the last
  673. value of edge
  674. - 'symmetric', means it reflects the values on the edge repeating the last
  675. value of edge
  676. Returns:
  677. img (PIL image), Padded image.
  678. """
  679. if not is_pil(img):
  680. raise TypeError(augment_error_message.format(type(img)))
  681. if isinstance(padding, numbers.Number):
  682. top = bottom = left = right = padding
  683. elif isinstance(padding, (tuple, list)):
  684. if len(padding) == 2:
  685. left = right = padding[0]
  686. top = bottom = padding[1]
  687. elif len(padding) == 4:
  688. left = padding[0]
  689. top = padding[1]
  690. right = padding[2]
  691. bottom = padding[3]
  692. else:
  693. raise ValueError("The size of the padding list or tuple should be 2 or 4.")
  694. else:
  695. raise TypeError("Padding can be any of: a number, a tuple or list of size 2 or 4.")
  696. if not isinstance(fill_value, (numbers.Number, str, tuple)):
  697. raise TypeError("fill_value can be any of: an integer, a string or a tuple.")
  698. if padding_mode not in ['constant', 'edge', 'reflect', 'symmetric']:
  699. raise ValueError("Padding mode should be 'constant', 'edge', 'reflect', or 'symmetric'.")
  700. if padding_mode == 'constant':
  701. if img.mode == 'P':
  702. palette = img.getpalette()
  703. image = ImageOps.expand(img, border=padding, fill=fill_value)
  704. image.putpalette(palette)
  705. return image
  706. return ImageOps.expand(img, border=padding, fill=fill_value)
  707. if img.mode == 'P':
  708. palette = img.getpalette()
  709. img = np.asarray(img)
  710. img = np.pad(img, ((top, bottom), (left, right)), padding_mode)
  711. img = Image.fromarray(img)
  712. img.putpalette(palette)
  713. return img
  714. img = np.asarray(img)
  715. if len(img.shape) == 3:
  716. img = np.pad(img, ((top, bottom), (left, right), (0, 0)), padding_mode)
  717. if len(img.shape) == 2:
  718. img = np.pad(img, ((top, bottom), (left, right)), padding_mode)
  719. return Image.fromarray(img)
  720. def get_perspective_params(img, distortion_scale):
  721. """Helper function to get parameters for RandomPerspective.
  722. """
  723. img_width, img_height = img.size
  724. distorted_half_width = int(img_width / 2 * distortion_scale)
  725. distorted_half_height = int(img_height / 2 * distortion_scale)
  726. top_left = (random.randint(0, distorted_half_width),
  727. random.randint(0, distorted_half_height))
  728. top_right = (random.randint(img_width - distorted_half_width - 1, img_width - 1),
  729. random.randint(0, distorted_half_height))
  730. bottom_right = (random.randint(img_width - distorted_half_width - 1, img_width - 1),
  731. random.randint(img_height - distorted_half_height - 1, img_height - 1))
  732. bottom_left = (random.randint(0, distorted_half_width),
  733. random.randint(img_height - distorted_half_height - 1, img_height - 1))
  734. start_points = [(0, 0), (img_width - 1, 0), (img_width - 1, img_height - 1), (0, img_height - 1)]
  735. end_points = [top_left, top_right, bottom_right, bottom_left]
  736. return start_points, end_points
  737. def perspective(img, start_points, end_points, interpolation=Inter.BICUBIC):
  738. """
  739. Apply perspective transformation to the input PIL image.
  740. Args:
  741. img (PIL image): PIL image to be applied perspective transformation.
  742. start_points (list): List of [top_left, top_right, bottom_right, bottom_left] of the original image.
  743. end_points: List of [top_left, top_right, bottom_right, bottom_left] of the transformed image.
  744. interpolation (interpolation mode): Image interpolation mode, Default is Inter.BICUBIC = 3.
  745. Returns:
  746. img (PIL image), Image after being perspectively transformed.
  747. """
  748. def _input_to_coeffs(original_points, transformed_points):
  749. # Get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
  750. # According to "Using Projective Geometry to Correct a Camera" from AMS.
  751. # http://www.ams.org/publicoutreach/feature-column/fc-2013-03
  752. # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Geometry.c#L377
  753. matrix = []
  754. for pt1, pt2 in zip(transformed_points, original_points):
  755. matrix.append([pt1[0], pt1[1], 1, 0, 0, 0, -pt2[0] * pt1[0], -pt2[0] * pt1[1]])
  756. matrix.append([0, 0, 0, pt1[0], pt1[1], 1, -pt2[1] * pt1[0], -pt2[1] * pt1[1]])
  757. matrix_a = np.array(matrix, dtype=np.float)
  758. matrix_b = np.array(original_points, dtype=np.float).reshape(8)
  759. res = np.linalg.lstsq(matrix_a, matrix_b, rcond=None)[0]
  760. return res.tolist()
  761. if not is_pil(img):
  762. raise TypeError(augment_error_message.format(type(img)))
  763. coeffs = _input_to_coeffs(start_points, end_points)
  764. return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation)
  765. def get_erase_params(np_img, scale, ratio, value, bounded, max_attempts):
  766. """Helper function to get parameters for RandomErasing/ Cutout.
  767. """
  768. if not is_numpy(np_img):
  769. raise TypeError('img should be NumPy array. Got {}.'.format(type(np_img)))
  770. image_c, image_h, image_w = np_img.shape
  771. area = image_h * image_w
  772. for _ in range(max_attempts):
  773. erase_area = random.uniform(scale[0], scale[1]) * area
  774. aspect_ratio = random.uniform(ratio[0], ratio[1])
  775. erase_w = int(round(math.sqrt(erase_area * aspect_ratio)))
  776. erase_h = int(round(erase_w / aspect_ratio))
  777. erase_shape = (image_c, erase_h, erase_w)
  778. if erase_h < image_h and erase_w < image_w:
  779. if bounded:
  780. i = random.randint(0, image_h - erase_h)
  781. j = random.randint(0, image_w - erase_w)
  782. else:
  783. def clip(x, lower, upper):
  784. return max(lower, min(x, upper))
  785. x = random.randint(0, image_w)
  786. y = random.randint(0, image_h)
  787. x1 = clip(x - erase_w // 2, 0, image_w)
  788. x2 = clip(x + erase_w // 2, 0, image_w)
  789. y1 = clip(y - erase_h // 2, 0, image_h)
  790. y2 = clip(y + erase_h // 2, 0, image_h)
  791. i, j, erase_h, erase_w = y1, x1, y2 - y1, x2 - x1
  792. if isinstance(value, numbers.Number):
  793. erase_value = value
  794. elif isinstance(value, (str, bytes)):
  795. erase_value = np.random.normal(loc=0.0, scale=1.0, size=erase_shape)
  796. elif isinstance(value, (tuple, list)) and len(value) == 3:
  797. value = np.array(value)
  798. erase_value = np.multiply(np.ones(erase_shape), value[:, None, None])
  799. else:
  800. raise ValueError("The value for erasing should be either a single value, or a string "
  801. "'random', or a sequence of 3 elements for RGB respectively.")
  802. return i, j, erase_h, erase_w, erase_value
  803. # exceeding max_attempts, return original image
  804. return 0, 0, image_h, image_w, np_img
  805. def erase(np_img, i, j, height, width, erase_value, inplace=False):
  806. """
  807. Erase the pixels, within a selected rectangle region, to the given value. Applied on the input NumPy image array.
  808. Args:
  809. np_img (numpy.ndarray): NumPy image array of shape (C, H, W) to be erased.
  810. i (int): The height component of the top left corner (height, width).
  811. j (int): The width component of the top left corner (height, width).
  812. height (int): Height of the erased region.
  813. width (int): Width of the erased region.
  814. erase_value: Erase value return from helper function get_erase_params().
  815. inplace (bool, optional): Apply this transform inplace. Default is False.
  816. Returns:
  817. np_img (numpy.ndarray), Erased NumPy image array.
  818. """
  819. if not is_numpy(np_img):
  820. raise TypeError('img should be NumPy array. Got {}.'.format(type(np_img)))
  821. if not inplace:
  822. np_img = np_img.copy()
  823. # (i, j) here are the coordinates of axes (height, width) as in CHW
  824. np_img[:, i:i + height, j:j + width] = erase_value
  825. return np_img
  826. def linear_transform(np_img, transformation_matrix, mean_vector):
  827. """
  828. Apply linear transformation to the input NumPy image array, given a square transformation matrix and a mean_vector.
  829. The transformation first flattens the input array and subtract mean_vector from it, then computes the
  830. dot product with the transformation matrix, and reshapes it back to its original shape.
  831. Args:
  832. np_img (numpy.ndarray): NumPy image array of shape (C, H, W) to be linear transformed.
  833. transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W.
  834. mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where D = C x H x W.
  835. Returns:
  836. np_img (numpy.ndarray), Linear transformed image.
  837. """
  838. if not is_numpy(np_img):
  839. raise TypeError('img should be NumPy array. Got {}'.format(type(np_img)))
  840. if transformation_matrix.shape[0] != transformation_matrix.shape[1]:
  841. raise ValueError("transformation_matrix should be a square matrix. "
  842. "Got shape {} instead".format(transformation_matrix.shape))
  843. if np.prod(np_img.shape) != transformation_matrix.shape[0]:
  844. raise ValueError("transformation_matrix shape {0} not compatible with "
  845. "Numpy image shape {1}.".format(transformation_matrix.shape, np_img.shape))
  846. if mean_vector.shape[0] != transformation_matrix.shape[0]:
  847. raise ValueError("mean_vector length {0} should match either one dimension of the square "
  848. "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
  849. zero_centered_img = np_img.reshape(1, -1) - mean_vector
  850. transformed_img = np.dot(zero_centered_img, transformation_matrix).reshape(np_img.shape)
  851. return transformed_img
  852. def random_affine(img, angle, translations, scale, shear, resample, fill_value=0):
  853. """
  854. Applies a random Affine transformation on the input PIL image.
  855. Args:
  856. img (PIL image): Image to be applied affine transformation.
  857. angle (Union[int, float]): Rotation angle in degrees, clockwise.
  858. translations (sequence): Translations in horizontal and vertical axis.
  859. scale (float): Scale parameter, a single number.
  860. shear (Union[float, sequence]): Shear amount parallel to X axis and Y axis.
  861. resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter.
  862. fill_value (Union[tuple int], optional): Optional fill_value to fill the area outside the transform
  863. in the output image. Used only in Pillow versions > 5.0.0.
  864. If None, no filling is performed.
  865. Returns:
  866. img (PIL image), Randomly affine transformed image.
  867. """
  868. if not is_pil(img):
  869. raise ValueError("Input image should be a Pillow image.")
  870. # rotation
  871. angle = random.uniform(angle[0], angle[1])
  872. # translation
  873. if translations is not None:
  874. max_dx = translations[0] * img.size[0]
  875. max_dy = translations[1] * img.size[1]
  876. translations = (np.round(random.uniform(-max_dx, max_dx)),
  877. np.round(random.uniform(-max_dy, max_dy)))
  878. else:
  879. translations = (0, 0)
  880. # scale
  881. if scale is not None:
  882. scale = random.uniform(scale[0], scale[1])
  883. else:
  884. scale = 1.0
  885. # shear
  886. if shear is not None:
  887. if len(shear) == 2:
  888. shear = [random.uniform(shear[0], shear[1]), 0.]
  889. elif len(shear) == 4:
  890. shear = [random.uniform(shear[0], shear[1]),
  891. random.uniform(shear[2], shear[3])]
  892. else:
  893. shear = 0.0
  894. output_size = img.size
  895. center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
  896. angle = math.radians(angle)
  897. if isinstance(shear, (tuple, list)) and len(shear) == 2:
  898. shear = [math.radians(s) for s in shear]
  899. elif isinstance(shear, numbers.Number):
  900. shear = math.radians(shear)
  901. shear = [shear, 0]
  902. else:
  903. raise ValueError(
  904. "Shear should be a single value or a tuple/list containing " +
  905. "two values. Got {}.".format(shear))
  906. scale = 1.0 / scale
  907. # Inverted rotation matrix with scale and shear
  908. d = math.cos(angle + shear[0]) * math.cos(angle + shear[1]) + \
  909. math.sin(angle + shear[0]) * math.sin(angle + shear[1])
  910. matrix = [
  911. math.cos(angle + shear[0]), math.sin(angle + shear[0]), 0,
  912. -math.sin(angle + shear[1]), math.cos(angle + shear[1]), 0
  913. ]
  914. matrix = [scale / d * m for m in matrix]
  915. # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
  916. matrix[2] += matrix[0] * (-center[0] - translations[0]) + matrix[1] * (-center[1] - translations[1])
  917. matrix[5] += matrix[3] * (-center[0] - translations[0]) + matrix[4] * (-center[1] - translations[1])
  918. # Apply center translation: C * RSS^-1 * C^-1 * T^-1
  919. matrix[2] += center[0]
  920. matrix[5] += center[1]
  921. if __version__ >= '5':
  922. kwargs = {"fillcolor": fill_value}
  923. else:
  924. kwargs = {}
  925. return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)
  926. def mix_up_single(batch_size, img, label, alpha=0.2):
  927. """
  928. Apply mix up transformation to image and label in single batch internal, One hot encoding should done before this.
  929. Args:
  930. batch_size (int): the batch size of dataset.
  931. img (numpy.ndarray): NumPy image to be applied mix up transformation.
  932. label (numpy.ndarray): NumPy label to be applied mix up transformation.
  933. alpha (float): the mix up rate.
  934. Returns:
  935. mix_img (numpy.ndarray): NumPy image after being applied mix up transformation.
  936. mix_label (numpy.ndarray): NumPy label after being applied mix up transformation.
  937. """
  938. def cir_shift(data):
  939. index = list(range(1, batch_size)) + [0]
  940. data = data[index, ...]
  941. return data
  942. lam = np.random.beta(alpha, alpha, batch_size)
  943. lam_img = lam.reshape((batch_size, 1, 1, 1))
  944. mix_img = lam_img * img + (1 - lam_img) * cir_shift(img)
  945. lam_label = lam.reshape((batch_size, 1))
  946. mix_label = lam_label * label + (1 - lam_label) * cir_shift(label)
  947. return mix_img, mix_label
  948. def mix_up_muti(tmp, batch_size, img, label, alpha=0.2):
  949. """
  950. Apply mix up transformation to image and label in continuous batch, one hot encoding should done before this.
  951. Args:
  952. tmp (class object): mainly for saving the tmp parameter.
  953. batch_size (int): the batch size of dataset.
  954. img (numpy.ndarray): NumPy image to be applied mix up transformation.
  955. label (numpy.ndarray): NumPy label to be applied mix up transformation.
  956. alpha (float): refer to the mix up rate.
  957. Returns:
  958. mix_img (numpy.ndarray): NumPy image after being applied mix up transformation.
  959. mix_label (numpy.ndarray): NumPy label after being applied mix up transformation.
  960. """
  961. lam = np.random.beta(alpha, alpha, batch_size)
  962. if tmp.is_first:
  963. lam = np.ones(batch_size)
  964. tmp.is_first = False
  965. lam_img = lam.reshape((batch_size, 1, 1, 1))
  966. mix_img = lam_img * img + (1 - lam_img) * tmp.image
  967. lam_label = lam.reshape(batch_size, 1)
  968. mix_label = lam_label * label + (1 - lam_label) * tmp.label
  969. tmp.image = mix_img
  970. tmp.label = mix_label
  971. return mix_img, mix_label
  972. def rgb_to_hsv(np_rgb_img, is_hwc):
  973. """
  974. Convert RGB img to HSV img.
  975. Args:
  976. np_rgb_img (numpy.ndarray): NumPy RGB image array of shape (H, W, C) or (C, H, W) to be converted.
  977. is_hwc (Bool): If True, the shape of np_hsv_img is (H, W, C), otherwise must be (C, H, W).
  978. Returns:
  979. np_hsv_img (numpy.ndarray), NumPy HSV image with same type of np_rgb_img.
  980. """
  981. if is_hwc:
  982. r, g, b = np_rgb_img[:, :, 0], np_rgb_img[:, :, 1], np_rgb_img[:, :, 2]
  983. else:
  984. r, g, b = np_rgb_img[0, :, :], np_rgb_img[1, :, :], np_rgb_img[2, :, :]
  985. to_hsv = np.vectorize(colorsys.rgb_to_hsv)
  986. h, s, v = to_hsv(r, g, b)
  987. if is_hwc:
  988. axis = 2
  989. else:
  990. axis = 0
  991. np_hsv_img = np.stack((h, s, v), axis=axis)
  992. return np_hsv_img
  993. def rgb_to_hsvs(np_rgb_imgs, is_hwc):
  994. """
  995. Convert RGB imgs to HSV imgs.
  996. Args:
  997. np_rgb_imgs (numpy.ndarray): NumPy RGB images array of shape (H, W, C) or (N, H, W, C),
  998. or (C, H, W) or (N, C, H, W) to be converted.
  999. is_hwc (Bool): If True, the shape of np_rgb_imgs is (H, W, C) or (N, H, W, C);
  1000. If False, the shape of np_rgb_imgs is (C, H, W) or (N, C, H, W).
  1001. Returns:
  1002. np_hsv_imgs (numpy.ndarray), NumPy HSV images with same type of np_rgb_imgs.
  1003. """
  1004. if not is_numpy(np_rgb_imgs):
  1005. raise TypeError("img should be NumPy image. Got {}".format(type(np_rgb_imgs)))
  1006. shape_size = len(np_rgb_imgs.shape)
  1007. if not shape_size in (3, 4):
  1008. raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). \
  1009. Got {}.".format(np_rgb_imgs.shape))
  1010. if shape_size == 3:
  1011. batch_size = 0
  1012. if is_hwc:
  1013. num_channels = np_rgb_imgs.shape[2]
  1014. else:
  1015. num_channels = np_rgb_imgs.shape[0]
  1016. else:
  1017. batch_size = np_rgb_imgs.shape[0]
  1018. if is_hwc:
  1019. num_channels = np_rgb_imgs.shape[3]
  1020. else:
  1021. num_channels = np_rgb_imgs.shape[1]
  1022. if num_channels != 3:
  1023. raise TypeError("img should be 3 channels RGB img. Got {} channels.".format(num_channels))
  1024. if batch_size == 0:
  1025. return rgb_to_hsv(np_rgb_imgs, is_hwc)
  1026. return np.array([rgb_to_hsv(img, is_hwc) for img in np_rgb_imgs])
  1027. def hsv_to_rgb(np_hsv_img, is_hwc):
  1028. """
  1029. Convert HSV img to RGB img.
  1030. Args:
  1031. np_hsv_img (numpy.ndarray): NumPy HSV image array of shape (H, W, C) or (C, H, W) to be converted.
  1032. is_hwc (Bool): If True, the shape of np_hsv_img is (H, W, C), otherwise must be (C, H, W).
  1033. Returns:
  1034. np_rgb_img (numpy.ndarray), NumPy HSV image with same shape of np_hsv_img.
  1035. """
  1036. if is_hwc:
  1037. h, s, v = np_hsv_img[:, :, 0], np_hsv_img[:, :, 1], np_hsv_img[:, :, 2]
  1038. else:
  1039. h, s, v = np_hsv_img[0, :, :], np_hsv_img[1, :, :], np_hsv_img[2, :, :]
  1040. to_rgb = np.vectorize(colorsys.hsv_to_rgb)
  1041. r, g, b = to_rgb(h, s, v)
  1042. if is_hwc:
  1043. axis = 2
  1044. else:
  1045. axis = 0
  1046. np_rgb_img = np.stack((r, g, b), axis=axis)
  1047. return np_rgb_img
  1048. def hsv_to_rgbs(np_hsv_imgs, is_hwc):
  1049. """
  1050. Convert HSV imgs to RGB imgs.
  1051. Args:
  1052. np_hsv_imgs (numpy.ndarray): NumPy HSV images array of shape (H, W, C) or (N, H, W, C),
  1053. or (C, H, W) or (N, C, H, W) to be converted.
  1054. is_hwc (Bool): If True, the shape of np_hsv_imgs is (H, W, C) or (N, H, W, C);
  1055. If False, the shape of np_hsv_imgs is (C, H, W) or (N, C, H, W).
  1056. Returns:
  1057. np_rgb_imgs (numpy.ndarray), NumPy RGB images with same type of np_hsv_imgs.
  1058. """
  1059. if not is_numpy(np_hsv_imgs):
  1060. raise TypeError("img should be NumPy image. Got {}.".format(type(np_hsv_imgs)))
  1061. shape_size = len(np_hsv_imgs.shape)
  1062. if not shape_size in (3, 4):
  1063. raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). \
  1064. Got {}.".format(np_hsv_imgs.shape))
  1065. if shape_size == 3:
  1066. batch_size = 0
  1067. if is_hwc:
  1068. num_channels = np_hsv_imgs.shape[2]
  1069. else:
  1070. num_channels = np_hsv_imgs.shape[0]
  1071. else:
  1072. batch_size = np_hsv_imgs.shape[0]
  1073. if is_hwc:
  1074. num_channels = np_hsv_imgs.shape[3]
  1075. else:
  1076. num_channels = np_hsv_imgs.shape[1]
  1077. if num_channels != 3:
  1078. raise TypeError("img should be 3 channels RGB img. Got {} channels.".format(num_channels))
  1079. if batch_size == 0:
  1080. return hsv_to_rgb(np_hsv_imgs, is_hwc)
  1081. return np.array([hsv_to_rgb(img, is_hwc) for img in np_hsv_imgs])
  1082. def random_color(img, degrees):
  1083. """
  1084. Adjust the color of the input PIL image by a random degree.
  1085. Args:
  1086. img (PIL image): Image to be color adjusted.
  1087. degrees (sequence): Range of random color adjustment degrees.
  1088. It should be in (min, max) format (default=(0.1,1.9)).
  1089. Returns:
  1090. img (PIL image), Color adjusted image.
  1091. """
  1092. if not is_pil(img):
  1093. raise TypeError("img should be PIL image. Got {}.".format(type(img)))
  1094. v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
  1095. return ImageEnhance.Color(img).enhance(v)
  1096. def random_sharpness(img, degrees):
  1097. """
  1098. Adjust the sharpness of the input PIL image by a random degree.
  1099. Args:
  1100. img (PIL image): Image to be sharpness adjusted.
  1101. degrees (sequence): Range of random sharpness adjustment degrees.
  1102. It should be in (min, max) format (default=(0.1,1.9)).
  1103. Returns:
  1104. img (PIL image), Sharpness adjusted image.
  1105. """
  1106. if not is_pil(img):
  1107. raise TypeError("img should be PIL image. Got {}.".format(type(img)))
  1108. v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
  1109. return ImageEnhance.Sharpness(img).enhance(v)
  1110. def auto_contrast(img, cutoff, ignore):
  1111. """
  1112. Automatically maximize the contrast of the input PIL image.
  1113. Args:
  1114. img (PIL image): Image to be augmented with AutoContrast.
  1115. cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0).
  1116. ignore (Union[int, sequence], optional): Pixel values to ignore (default=None).
  1117. Returns:
  1118. img (PIL image), Augmented image.
  1119. """
  1120. if not is_pil(img):
  1121. raise TypeError("img should be PIL image. Got {}.".format(type(img)))
  1122. return ImageOps.autocontrast(img, cutoff, ignore)
  1123. def invert_color(img):
  1124. """
  1125. Invert colors of input PIL image.
  1126. Args:
  1127. img (PIL image): Image to be color inverted.
  1128. Returns:
  1129. img (PIL image), Color inverted image.
  1130. """
  1131. if not is_pil(img):
  1132. raise TypeError("img should be PIL image. Got {}.".format(type(img)))
  1133. return ImageOps.invert(img)
  1134. def equalize(img):
  1135. """
  1136. Equalize the histogram of input PIL image.
  1137. Args:
  1138. img (PIL image): Image to be equalized
  1139. Returns:
  1140. img (PIL image), Equalized image.
  1141. """
  1142. if not is_pil(img):
  1143. raise TypeError("img should be PIL image. Got {}.".format(type(img)))
  1144. return ImageOps.equalize(img)
  1145. def uniform_augment(img, transforms, num_ops):
  1146. """
  1147. Uniformly select and apply a number of transforms sequentially from
  1148. a list of transforms. Randomly assigns a probability to each transform for
  1149. each image to decide whether apply it or not.
  1150. All the transforms in transform list must have the same input/output data type.
  1151. Args:
  1152. img: Image to be applied transformation.
  1153. transforms (list): List of transformations to be chosen from to apply.
  1154. num_ops (int): number of transforms to sequentially aaply.
  1155. Returns:
  1156. img, Transformed image.
  1157. """
  1158. op_idx = np.random.choice(len(transforms), size=num_ops, replace=False)
  1159. for idx in op_idx:
  1160. AugmentOp = transforms[idx]
  1161. pr = random.random()
  1162. if random.random() < pr:
  1163. img = AugmentOp(img.copy())
  1164. return img