You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

py_transforms_util.py 55 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Built-in py_transforms_utils functions.
  16. """
  17. import io
  18. import math
  19. import numbers
  20. import random
  21. import colorsys
  22. import numpy as np
  23. from PIL import Image, ImageOps, ImageEnhance, __version__
  24. from .utils import Inter
  25. from ..core.py_util_helpers import is_numpy
# Shared error text used by the PIL-based transforms below whenever the input
# is not a PIL image; callers format it with the offending type.
augment_error_message = 'img should be PIL image. Got {}. Use Decode() for encoded data or ToPIL() for decoded data.'
  27. def is_pil(img):
  28. """
  29. Check if the input image is PIL format.
  30. Args:
  31. img: Image to be checked.
  32. Returns:
  33. Bool, True if input is PIL image.
  34. """
  35. return isinstance(img, Image.Image)
  36. def normalize(img, mean, std):
  37. """
  38. Normalize the image between [0, 1] with respect to mean and standard deviation.
  39. Args:
  40. img (numpy.ndarray): Image array of shape CHW to be normalized.
  41. mean (list): List of mean values for each channel, w.r.t channel order.
  42. std (list): List of standard deviations for each channel, w.r.t. channel order.
  43. Returns:
  44. img (numpy.ndarray), Normalized image.
  45. """
  46. if not is_numpy(img):
  47. raise TypeError('img should be NumPy image. Got {}'.format(type(img)))
  48. num_channels = img.shape[0] # shape is (C, H, W)
  49. if len(mean) != len(std):
  50. raise ValueError("Length of mean and std must be equal")
  51. # if length equal to 1, adjust the mean and std arrays to have the correct
  52. # number of channels (replicate the values)
  53. if len(mean) == 1:
  54. mean = [mean[0]] * num_channels
  55. std = [std[0]] * num_channels
  56. elif len(mean) != num_channels:
  57. raise ValueError("Length of mean and std must both be 1 or equal to the number of channels({0})"
  58. .format(num_channels))
  59. mean = np.array(mean, dtype=img.dtype)
  60. std = np.array(std, dtype=img.dtype)
  61. return (img - mean[:, None, None]) / std[:, None, None]
  62. def decode(img):
  63. """
  64. Decode the input image to PIL image format in RGB mode.
  65. Args:
  66. img: Image to be decoded.
  67. Returns:
  68. img (PIL image), Decoded image in RGB mode.
  69. """
  70. try:
  71. data = io.BytesIO(img)
  72. img = Image.open(data)
  73. return img.convert('RGB')
  74. except IOError as e:
  75. raise ValueError("{0}\nWARNING: Failed to decode given image.".format(e))
  76. except AttributeError as e:
  77. raise ValueError("{0}\nWARNING: Failed to decode, Image might already be decoded.".format(e))
  78. def hwc_to_chw(img):
  79. """
  80. Transpose the input image; shape (H, W, C) to shape (C, H, W).
  81. Args:
  82. img (numpy.ndarray): Image to be converted.
  83. Returns:
  84. img (numpy.ndarray), Converted image.
  85. """
  86. if is_numpy(img):
  87. return img.transpose(2, 0, 1).copy()
  88. raise TypeError('img should be NumPy array. Got {}'.format(type(img)))
  89. def to_tensor(img, output_type):
  90. """
  91. Change the input image (PIL image or NumPy image array) to NumPy format.
  92. Args:
  93. img (Union[PIL image, numpy.ndarray]): Image to be converted.
  94. output_type: The datatype of the NumPy output. e.g. np.float32
  95. Returns:
  96. img (numpy.ndarray), Converted image.
  97. """
  98. if not (is_pil(img) or is_numpy(img)):
  99. raise TypeError('img should be PIL image or NumPy array. Got {}'.format(type(img)))
  100. img = np.asarray(img)
  101. if img.ndim not in (2, 3):
  102. raise ValueError('img dimension should be 2 or 3. Got {}'.format(img.ndim))
  103. if img.ndim == 2:
  104. img = img[:, :, None]
  105. img = hwc_to_chw(img)
  106. img = img / 255.
  107. return to_type(img, output_type)
  108. def to_pil(img):
  109. """
  110. Convert the input image to PIL format.
  111. Args:
  112. img: Image to be converted.
  113. Returns:
  114. img (PIL image), Converted image.
  115. """
  116. if not is_pil(img):
  117. return Image.fromarray(img)
  118. return img
  119. def horizontal_flip(img):
  120. """
  121. Flip the input image horizontally.
  122. Args:
  123. img (PIL image): Image to be flipped horizontally.
  124. Returns:
  125. img (PIL image), Horizontally flipped image.
  126. """
  127. if not is_pil(img):
  128. raise TypeError(augment_error_message.format(type(img)))
  129. return img.transpose(Image.FLIP_LEFT_RIGHT)
  130. def vertical_flip(img):
  131. """
  132. Flip the input image vertically.
  133. Args:
  134. img (PIL image): Image to be flipped vertically.
  135. Returns:
  136. img (PIL image), Vertically flipped image.
  137. """
  138. if not is_pil(img):
  139. raise TypeError(augment_error_message.format(type(img)))
  140. return img.transpose(Image.FLIP_TOP_BOTTOM)
  141. def random_horizontal_flip(img, prob):
  142. """
  143. Randomly flip the input image horizontally.
  144. Args:
  145. img (PIL image): Image to be flipped.
  146. If the given probability is above the random probability, then the image is flipped.
  147. prob (float): Probability of the image being flipped.
  148. Returns:
  149. img (PIL image), Converted image.
  150. """
  151. if not is_pil(img):
  152. raise TypeError(augment_error_message.format(type(img)))
  153. if prob > random.random():
  154. img = horizontal_flip(img)
  155. return img
  156. def random_vertical_flip(img, prob):
  157. """
  158. Randomly flip the input image vertically.
  159. Args:
  160. img (PIL image): Image to be flipped.
  161. If the given probability is above the random probability, then the image is flipped.
  162. prob (float): Probability of the image being flipped.
  163. Returns:
  164. img (PIL image), Converted image.
  165. """
  166. if not is_pil(img):
  167. raise TypeError(augment_error_message.format(type(img)))
  168. if prob > random.random():
  169. img = vertical_flip(img)
  170. return img
  171. def crop(img, top, left, height, width):
  172. """
  173. Crop the input PIL image.
  174. Args:
  175. img (PIL image): Image to be cropped. (0,0) denotes the top left corner of the image,
  176. in the directions of (width, height).
  177. top (int): Vertical component of the top left corner of the crop box.
  178. left (int): Horizontal component of the top left corner of the crop box.
  179. height (int): Height of the crop box.
  180. width (int): Width of the crop box.
  181. Returns:
  182. img (PIL image), Cropped image.
  183. """
  184. if not is_pil(img):
  185. raise TypeError(augment_error_message.format(type(img)))
  186. return img.crop((left, top, left + width, top + height))
  187. def resize(img, size, interpolation=Inter.BILINEAR):
  188. """
  189. Resize the input PIL image to desired size.
  190. Args:
  191. img (PIL image): Image to be resized.
  192. size (Union[int, sequence]): The output size of the resized image.
  193. If size is an integer, smaller edge of the image will be resized to this value with
  194. the same image aspect ratio.
  195. If size is a sequence of (height, width), this will be the desired output size.
  196. interpolation (interpolation mode): Image interpolation mode. Default is Inter.BILINEAR = 2.
  197. Returns:
  198. img (PIL image), Resized image.
  199. """
  200. if not is_pil(img):
  201. raise TypeError(augment_error_message.format(type(img)))
  202. if not (isinstance(size, int) or (isinstance(size, (list, tuple)) and len(size) == 2)):
  203. raise TypeError('Size should be a single number or a list/tuple (h, w) of length 2.'
  204. 'Got {}'.format(size))
  205. if isinstance(size, int):
  206. img_width, img_height = img.size
  207. aspect_ratio = img_width / img_height # maintain the aspect ratio
  208. if (img_width <= img_height and img_width == size) or \
  209. (img_height <= img_width and img_height == size):
  210. return img
  211. if img_width < img_height:
  212. out_width = size
  213. out_height = int(size / aspect_ratio)
  214. return img.resize((out_width, out_height), interpolation)
  215. out_height = size
  216. out_width = int(size * aspect_ratio)
  217. return img.resize((out_width, out_height), interpolation)
  218. return img.resize(size[::-1], interpolation)
  219. def center_crop(img, size):
  220. """
  221. Crop the input PIL image at the center to the given size.
  222. Args:
  223. img (PIL image): Image to be cropped.
  224. size (Union[int, tuple]): The size of the crop box.
  225. If size is an integer, a square crop of size (size, size) is returned.
  226. If size is a sequence of length 2, it should be (height, width).
  227. Returns:
  228. img (PIL image), Cropped image.
  229. """
  230. if not is_pil(img):
  231. raise TypeError(augment_error_message.format(type(img)))
  232. if isinstance(size, int):
  233. size = (size, size)
  234. img_width, img_height = img.size
  235. crop_height, crop_width = size
  236. crop_top = int(round((img_height - crop_height) / 2.))
  237. crop_left = int(round((img_width - crop_width) / 2.))
  238. return crop(img, crop_top, crop_left, crop_height, crop_width)
def random_resize_crop(img, size, scale, ratio, interpolation=Inter.BILINEAR, max_attempts=10):
    """
    Crop the input PIL image to a random size and aspect ratio, then resize it.

    Args:
        img (PIL image): Image to be randomly cropped and resized.
        size (Union[int, sequence]): The size of the output image.
            If size is an integer, a square crop of size (size, size) is returned.
            If size is a sequence of length 2, it should be (height, width).
        scale (tuple): Range (min, max) of respective size of the original size to be cropped.
        ratio (tuple): Range (min, max) of aspect ratio to be cropped.
        interpolation (interpolation mode): Image interpolation mode. Default is Inter.BILINEAR = 2.
        max_attempts (int): The maximum number of attempts to propose a valid crop_area. Default 10.
            If exceeded, fall back to use center_crop instead.

    Returns:
        img (PIL image), Randomly cropped and resized image.

    Raises:
        TypeError: If img is not a PIL image or size is not an int/2-sequence.
        ValueError: If scale or ratio is not ordered as (min, max).
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    if isinstance(size, int):
        size = (size, size)
    elif isinstance(size, (tuple, list)) and len(size) == 2:
        size = size  # already (height, width); no-op branch kept for explicit validation
    else:
        raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
    if scale[0] > scale[1] or ratio[0] > ratio[1]:
        raise ValueError("Range should be in the order of (min, max).")
    def _input_to_factor(img, scale, ratio):
        # Propose a crop box (top, left, height, width) whose area fraction and
        # aspect ratio are sampled from the requested ranges.
        img_width, img_height = img.size
        img_area = img_width * img_height
        for _ in range(max_attempts):
            crop_area = random.uniform(scale[0], scale[1]) * img_area
            # in case of non-symmetrical aspect ratios,
            # use uniform distribution on a logarithmic scale.
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))
            # derive box dimensions from area and aspect ratio:
            # width = sqrt(area * ratio), height = width / ratio
            width = int(round(math.sqrt(crop_area * aspect_ratio)))
            height = int(round(width / aspect_ratio))
            if 0 < width <= img_width and 0 < height <= img_height:
                # proposal fits: place it at a uniformly random position
                top = random.randint(0, img_height - height)
                left = random.randint(0, img_width - width)
                return top, left, height, width
        # exceeding max_attempts, use center crop
        # (clamp the crop's aspect ratio to the nearest bound of the requested range)
        img_ratio = img_width / img_height
        if img_ratio < ratio[0]:
            width = img_width
            height = int(round(width / ratio[0]))
        elif img_ratio > ratio[1]:
            height = img_height
            width = int(round(height * ratio[1]))
        else:
            width = img_width
            height = img_height
        top = int(round((img_height - height) / 2.))
        left = int(round((img_width - width) / 2.))
        return top, left, height, width
    top, left, height, width = _input_to_factor(img, scale, ratio)
    img = crop(img, top, left, height, width)
    img = resize(img, size, interpolation)
    return img
def random_crop(img, size, padding, pad_if_needed, fill_value, padding_mode):
    """
    Crop the input PIL image at a random location, optionally padding it first.

    Args:
        img (PIL image): Image to be randomly cropped.
        size (Union[int, sequence]): The output size of the cropped image.
            If size is an integer, a square crop of size (size, size) is returned.
            If size is a sequence of length 2, it should be (height, width).
        padding (Union[int, sequence], optional): The number of pixels to pad the image.
            If a single number is provided, it pads all borders with this value.
            If a tuple or list of 2 values are provided, it pads the (left and top)
            with the first value and (right and bottom) with the second value.
            If 4 values are provided as a list or tuple,
            it pads the left, top, right and bottom respectively.
            Default is None.
        pad_if_needed (bool): Pad the image if either side is smaller than
            the given output size. Default is False.
        fill_value (Union[int, tuple]): The pixel intensity of the borders if
            the padding_mode is 'constant'. If it is a 3-tuple, it is used to
            fill R, G, B channels respectively.
        padding_mode (str): The method of padding. Can be any of
            ['constant', 'edge', 'reflect', 'symmetric'].

    Returns:
        PIL image, Cropped image.

    Raises:
        TypeError: If img is not a PIL image or size is not an int/2-sequence.
        ValueError: If the crop size exceeds the (possibly padded) image size.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    if isinstance(size, int):
        size = (size, size)
    elif isinstance(size, (tuple, list)) and len(size) == 2:
        size = size  # already (height, width); no-op branch kept for explicit validation
    else:
        raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")
    def _input_to_factor(img, size):
        # Pick a uniformly random (top, left) so the crop box fits inside img.
        img_width, img_height = img.size
        height, width = size
        if height > img_height or width > img_width:
            raise ValueError("Crop size {} is larger than input image size {}".format(size, (img_height, img_width)))
        if width == img_width and height == img_height:
            # crop covers the whole image: only one possible position
            return 0, 0, img_height, img_width
        top = random.randint(0, img_height - height)
        left = random.randint(0, img_width - width)
        return top, left, height, width
    # NOTE: `pad` is defined elsewhere in this file — presumably mirrors
    # ImageOps-style padding with (left/top, right/bottom) tuples; verify there.
    if padding is not None:
        img = pad(img, padding, fill_value, padding_mode)
    # pad width when needed, img.size (width, height), crop size (height, width)
    if pad_if_needed and img.size[0] < size[1]:
        img = pad(img, (size[1] - img.size[0], 0), fill_value, padding_mode)
    # pad height when needed
    if pad_if_needed and img.size[1] < size[0]:
        img = pad(img, (0, size[0] - img.size[1]), fill_value, padding_mode)
    top, left, height, width = _input_to_factor(img, size)
    return crop(img, top, left, height, width)
  357. def adjust_brightness(img, brightness_factor):
  358. """
  359. Adjust brightness of an image.
  360. Args:
  361. img (PIL image): Image to be adjusted.
  362. brightness_factor (float): A non negative number indicated the factor by which
  363. the brightness is adjusted. 0 gives a black image, 1 gives the original.
  364. Returns:
  365. img (PIL image), Brightness adjusted image.
  366. """
  367. if not is_pil(img):
  368. raise TypeError(augment_error_message.format(type(img)))
  369. enhancer = ImageEnhance.Brightness(img)
  370. img = enhancer.enhance(brightness_factor)
  371. return img
  372. def adjust_contrast(img, contrast_factor):
  373. """
  374. Adjust contrast of an image.
  375. Args:
  376. img (PIL image): PIL image to be adjusted.
  377. contrast_factor (float): A non negative number indicated the factor by which
  378. the contrast is adjusted. 0 gives a solid gray image, 1 gives the original.
  379. Returns:
  380. img (PIL image), Contrast adjusted image.
  381. """
  382. if not is_pil(img):
  383. raise TypeError(augment_error_message.format(type(img)))
  384. enhancer = ImageEnhance.Contrast(img)
  385. img = enhancer.enhance(contrast_factor)
  386. return img
  387. def adjust_saturation(img, saturation_factor):
  388. """
  389. Adjust saturation of an image.
  390. Args:
  391. img (PIL image): PIL image to be adjusted.
  392. saturation_factor (float): A non negative number indicated the factor by which
  393. the saturation is adjusted. 0 will give a black and white image, 1 will
  394. give the original.
  395. Returns:
  396. img (PIL image), Saturation adjusted image.
  397. """
  398. if not is_pil(img):
  399. raise TypeError(augment_error_message.format(type(img)))
  400. enhancer = ImageEnhance.Color(img)
  401. img = enhancer.enhance(saturation_factor)
  402. return img
def adjust_hue(img, hue_factor):
    """
    Adjust hue of an image by shifting the H channel in HSV space.

    Args:
        img (PIL image): PIL image to be adjusted.
        hue_factor (float): Amount to shift the Hue channel. Value should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel. This
            is because Hue wraps around when rotated 360 degrees.
            0 means no shift that gives the original image while both -0.5 and 0.5
            will give an image with complementary colors.

    Returns:
        img (PIL image), Hue adjusted image.

    Raises:
        ValueError: If hue_factor is outside [-0.5, 0.5].
        TypeError: If img is not a PIL image.
    """
    image = img
    image_hue_factor = hue_factor
    if not -0.5 <= image_hue_factor <= 0.5:
        raise ValueError('image_hue_factor {} is not in [-0.5, 0.5].'.format(image_hue_factor))
    if not is_pil(image):
        raise TypeError(augment_error_message.format(type(image)))
    mode = image.mode
    # single-channel / non-color modes have no hue to adjust; return unchanged
    if mode in {'L', '1', 'I', 'F'}:
        return image
    hue, saturation, value = img.convert('HSV').split()
    np_hue = np.array(hue, dtype=np.uint8)
    # uint8 addition wraps modulo 256, which is exactly the desired hue
    # wraparound; suppress the overflow warning on purpose
    with np.errstate(over='ignore'):
        np_hue += np.uint8(image_hue_factor * 255)
    hue = Image.fromarray(np_hue, 'L')
    # recombine shifted hue with the original S and V, convert back to input mode
    image = Image.merge('HSV', (hue, saturation, value)).convert(mode)
    return image
  432. def to_type(img, output_type):
  433. """
  434. Convert the NumPy image array to desired NumPy dtype.
  435. Args:
  436. img (numpy): NumPy image to cast to desired NumPy dtype.
  437. output_type (Numpy datatype): NumPy dtype to cast to.
  438. Returns:
  439. img (numpy.ndarray), Converted image.
  440. """
  441. if not is_numpy(img):
  442. raise TypeError('img should be NumPy image. Got {}'.format(type(img)))
  443. return img.astype(output_type)
  444. def rotate(img, angle, resample, expand, center, fill_value):
  445. """
  446. Rotate the input PIL image by angle.
  447. Args:
  448. img (PIL image): Image to be rotated.
  449. angle (int or float): Rotation angle in degrees, counter-clockwise.
  450. resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter.
  451. If omitted, or if the image has mode "1" or "P", it is set to be Inter.NEAREST.
  452. expand (bool, optional): Optional expansion flag. If set to True, expand the output
  453. image to make it large enough to hold the entire rotated image.
  454. If set to False or omitted, make the output image the same size as the input.
  455. Note that the expand flag assumes rotation around the center and no translation.
  456. center (tuple, optional): Optional center of rotation (a 2-tuple).
  457. Origin is the top left corner.
  458. fill_value (Union[int, tuple]): Optional fill color for the area outside the rotated image.
  459. If it is a 3-tuple, it is used for R, G, B channels respectively.
  460. If it is an integer, it is used for all RGB channels.
  461. Returns:
  462. img (PIL image), Rotated image.
  463. https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate
  464. """
  465. if not is_pil(img):
  466. raise TypeError(augment_error_message.format(type(img)))
  467. if isinstance(fill_value, int):
  468. fill_value = tuple([fill_value] * 3)
  469. return img.rotate(angle, resample, expand, center, fillcolor=fill_value)
def random_color_adjust(img, brightness, contrast, saturation, hue):
    """
    Randomly adjust the brightness, contrast, saturation, and hue of an image.

    Args:
        img (PIL image): Image to have its color adjusted randomly.
        brightness (Union[float, tuple]): Brightness adjustment factor. Cannot be negative.
            If it is a float, the factor is uniformly chosen from the range [max(0, 1-brightness), 1+brightness].
            If it is a sequence, it should be [min, max] for the range.
        contrast (Union[float, tuple]): Contrast adjustment factor. Cannot be negative.
            If it is a float, the factor is uniformly chosen from the range [max(0, 1-contrast), 1+contrast].
            If it is a sequence, it should be [min, max] for the range.
        saturation (Union[float, tuple]): Saturation adjustment factor. Cannot be negative.
            If it is a float, the factor is uniformly chosen from the range [max(0, 1-saturation), 1+saturation].
            If it is a sequence, it should be [min, max] for the range.
        hue (Union[float, tuple]): Hue adjustment factor.
            If it is a float, the range will be [-hue, hue]. Value should be 0 <= hue <= 0.5.
            If it is a sequence, it should be [min, max] where -0.5 <= min <= max <= 0.5.

    Returns:
        img (PIL image), Image after random adjustment of its color.

    Raises:
        TypeError: If img is not a PIL image, or a factor has a wrong type.
        ValueError: If a factor value or range is invalid.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    def _input_to_factor(value, input_name, center=1, bound=(0, float('inf')), non_negative=True):
        # Normalize a scalar-or-range spec into one uniformly sampled factor.
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("The input value of {} cannot be negative.".format(input_name))
            # convert value into a range
            value = [center - value, center + value]
            if non_negative:
                # clamp the lower bound at 0 (e.g. brightness 1.5 -> [0, 2.5])
                value[0] = max(0, value[0])
        elif isinstance(value, (list, tuple)) and len(value) == 2:
            # explicit [min, max]: just validate ordering and bounds
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("Please check your value range of {} is valid and "
                                 "within the bound {}".format(input_name, bound))
        else:
            raise TypeError("Input of {} should be either a single value, or a list/tuple of "
                            "length 2.".format(input_name))
        factor = random.uniform(value[0], value[1])
        return factor
    brightness_factor = _input_to_factor(brightness, 'brightness')
    contrast_factor = _input_to_factor(contrast, 'contrast')
    saturation_factor = _input_to_factor(saturation, 'saturation')
    # hue is centered at 0 and may be negative, unlike the other three
    hue_factor = _input_to_factor(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)
    transforms = []
    transforms.append(lambda img: adjust_brightness(img, brightness_factor))
    transforms.append(lambda img: adjust_contrast(img, contrast_factor))
    transforms.append(lambda img: adjust_saturation(img, saturation_factor))
    transforms.append(lambda img: adjust_hue(img, hue_factor))
    # apply color adjustments in a random order
    random.shuffle(transforms)
    for transform in transforms:
        img = transform(img)
    return img
  523. def random_rotation(img, degrees, resample, expand, center, fill_value):
  524. """
  525. Rotate the input PIL image by a random angle.
  526. Args:
  527. img (PIL image): Image to be rotated.
  528. degrees (Union[int, float, sequence]): Range of random rotation degrees.
  529. If degrees is a number, the range will be converted to (-degrees, degrees).
  530. If degrees is a sequence, it should be (min, max).
  531. resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter.
  532. If omitted, or if the image has mode "1" or "P", it is set to be Inter.NEAREST.
  533. expand (bool, optional): Optional expansion flag. If set to True, expand the output
  534. image to make it large enough to hold the entire rotated image.
  535. If set to False or omitted, make the output image the same size as the input.
  536. Note that the expand flag assumes rotation around the center and no translation.
  537. center (tuple, optional): Optional center of rotation (a 2-tuple).
  538. Origin is the top left corner.
  539. fill_value (Union[int, tuple]): Optional fill color for the area outside the rotated image.
  540. If it is a 3-tuple, it is used for R, G, B channels respectively.
  541. If it is an integer, it is used for all RGB channels.
  542. Returns:
  543. img (PIL image), Rotated image.
  544. https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate
  545. """
  546. if not is_pil(img):
  547. raise TypeError(augment_error_message.format(type(img)))
  548. if isinstance(degrees, numbers.Number):
  549. if degrees < 0:
  550. raise ValueError("If degrees is a single number, it cannot be negative.")
  551. degrees = (-degrees, degrees)
  552. elif isinstance(degrees, (list, tuple)):
  553. if len(degrees) != 2:
  554. raise ValueError("If degrees is a sequence, the length must be 2.")
  555. else:
  556. raise TypeError("Degrees must be a single non-negative number or a sequence")
  557. angle = random.uniform(degrees[0], degrees[1])
  558. return rotate(img, angle, resample, expand, center, fill_value)
  559. def five_crop(img, size):
  560. """
  561. Generate 5 cropped images (one central and four corners).
  562. Args:
  563. img (PIL image): PIL image to be cropped.
  564. size (Union[int, sequence]): The output size of the crop.
  565. If size is an integer, a square crop of size (size, size) is returned.
  566. If size is a sequence of length 2, it should be (height, width).
  567. Returns:
  568. img_tuple (tuple), a tuple of 5 PIL images
  569. (top_left, top_right, bottom_left, bottom_right, center).
  570. """
  571. if not is_pil(img):
  572. raise TypeError(augment_error_message.format(type(img)))
  573. if isinstance(size, int):
  574. size = (size, size)
  575. elif isinstance(size, (tuple, list)) and len(size) == 2:
  576. size = size
  577. else:
  578. raise TypeError("Size should be a single number or a list/tuple (h, w) of length 2.")
  579. # PIL image.size returns in (width, height) order
  580. img_width, img_height = img.size
  581. crop_height, crop_width = size
  582. if crop_height > img_height or crop_width > img_width:
  583. raise ValueError("Crop size {} is larger than input image size {}".format(size, (img_height, img_width)))
  584. center = center_crop(img, (crop_height, crop_width))
  585. top_left = img.crop((0, 0, crop_width, crop_height))
  586. top_right = img.crop((img_width - crop_width, 0, img_width, crop_height))
  587. bottom_left = img.crop((0, img_height - crop_height, crop_width, img_height))
  588. bottom_right = img.crop((img_width - crop_width, img_height - crop_height, img_width, img_height))
  589. return top_left, top_right, bottom_left, bottom_right, center
  590. def ten_crop(img, size, use_vertical_flip=False):
  591. """
  592. Generate 10 cropped images (first 5 from FiveCrop, second 5 from their flipped version).
  593. The default is horizontal flipping, use_vertical_flip=False.
  594. Args:
  595. img (PIL image): PIL image to be cropped.
  596. size (Union[int, sequence]): The output size of the crop.
  597. If size is an integer, a square crop of size (size, size) is returned.
  598. If size is a sequence of length 2, it should be (height, width).
  599. use_vertical_flip (bool): Flip the image vertically instead of horizontally if set to True.
  600. Returns:
  601. img_tuple (tuple), a tuple of 10 PIL images
  602. (top_left, top_right, bottom_left, bottom_right, center) of original image +
  603. (top_left, top_right, bottom_left, bottom_right, center) of flipped image.
  604. """
  605. if not is_pil(img):
  606. raise TypeError(augment_error_message.format(type(img)))
  607. if isinstance(size, int):
  608. size = (size, size)
  609. elif isinstance(size, (tuple, list)) and len(size) == 2:
  610. size = size
  611. else:
  612. raise TypeError("Size should be a single number or a list/tuple (h, w) of length 2.")
  613. first_five_crop = five_crop(img, size)
  614. if use_vertical_flip:
  615. img = vertical_flip(img)
  616. else:
  617. img = horizontal_flip(img)
  618. second_five_crop = five_crop(img, size)
  619. return first_five_crop + second_five_crop
  620. def grayscale(img, num_output_channels):
  621. """
  622. Convert the input PIL image to grayscale image.
  623. Args:
  624. img (PIL image): PIL image to be converted to grayscale.
  625. num_output_channels (int): Number of channels of the output grayscale image (1 or 3).
  626. Returns:
  627. img (PIL image), grayscaled image.
  628. """
  629. if not is_pil(img):
  630. raise TypeError(augment_error_message.format(type(img)))
  631. if num_output_channels == 1:
  632. img = img.convert('L')
  633. elif num_output_channels == 3:
  634. # each channel is the same grayscale layer
  635. img = img.convert('L')
  636. np_gray = np.array(img, dtype=np.uint8)
  637. np_img = np.dstack([np_gray, np_gray, np_gray])
  638. img = Image.fromarray(np_img, 'RGB')
  639. else:
  640. raise ValueError('num_output_channels should be either 1 or 3. Got {}'.format(num_output_channels))
  641. return img
  642. def pad(img, padding, fill_value, padding_mode):
  643. """
  644. Pads the image according to padding parameters.
  645. Args:
  646. img (PIL image): Image to be padded.
  647. padding (Union[int, sequence], optional): The number of pixels to pad the image.
  648. If a single number is provided, it pads all borders with this value.
  649. If a tuple or list of 2 values are provided, it pads the (left and top)
  650. with the first value and (right and bottom) with the second value.
  651. If 4 values are provided as a list or tuple,
  652. it pads the left, top, right and bottom respectively.
  653. Default is None.
  654. fill_value (Union[int, tuple]): The pixel intensity of the borders if
  655. the padding_mode is "constant". If it is a 3-tuple, it is used to
  656. fill R, G, B channels respectively.
  657. padding_mode (str): The method of padding. Can be any of
  658. ['constant', 'edge', 'reflect', 'symmetric'].
  659. - 'constant', means it fills the border with constant values
  660. - 'edge', means it pads with the last value on the edge
  661. - 'reflect', means it reflects the values on the edge omitting the last
  662. value of edge
  663. - 'symmetric', means it reflects the values on the edge repeating the last
  664. value of edge
  665. Returns:
  666. img (PIL image), Padded image.
  667. """
  668. if not is_pil(img):
  669. raise TypeError(augment_error_message.format(type(img)))
  670. if isinstance(padding, numbers.Number):
  671. top = bottom = left = right = padding
  672. elif isinstance(padding, (tuple, list)):
  673. if len(padding) == 2:
  674. left = right = padding[0]
  675. top = bottom = padding[1]
  676. elif len(padding) == 4:
  677. left = padding[0]
  678. top = padding[1]
  679. right = padding[2]
  680. bottom = padding[3]
  681. else:
  682. raise ValueError("The size of the padding list or tuple should be 2 or 4.")
  683. else:
  684. raise TypeError("Padding can be any of: a number, a tuple or list of size 2 or 4.")
  685. if not isinstance(fill_value, (numbers.Number, str, tuple)):
  686. raise TypeError("fill_value can be any of: an integer, a string or a tuple.")
  687. if padding_mode not in ['constant', 'edge', 'reflect', 'symmetric']:
  688. raise ValueError("Padding mode can be any of ['constant', 'edge', 'reflect', 'symmetric'].")
  689. if padding_mode == 'constant':
  690. if img.mode == 'P':
  691. palette = img.getpalette()
  692. image = ImageOps.expand(img, border=padding, fill=fill_value)
  693. image.putpalette(palette)
  694. return image
  695. return ImageOps.expand(img, border=padding, fill=fill_value)
  696. if img.mode == 'P':
  697. palette = img.getpalette()
  698. img = np.asarray(img)
  699. img = np.pad(img, ((top, bottom), (left, right)), padding_mode)
  700. img = Image.fromarray(img)
  701. img.putpalette(palette)
  702. return img
  703. img = np.asarray(img)
  704. if len(img.shape) == 3:
  705. img = np.pad(img, ((top, bottom), (left, right), (0, 0)), padding_mode)
  706. if len(img.shape) == 2:
  707. img = np.pad(img, ((top, bottom), (left, right)), padding_mode)
  708. return Image.fromarray(img)
  709. def get_perspective_params(img, distortion_scale):
  710. """Helper function to get parameters for RandomPerspective.
  711. """
  712. img_width, img_height = img.size
  713. distorted_half_width = int(img_width / 2 * distortion_scale)
  714. distorted_half_height = int(img_height / 2 * distortion_scale)
  715. top_left = (random.randint(0, distorted_half_width),
  716. random.randint(0, distorted_half_height))
  717. top_right = (random.randint(img_width - distorted_half_width - 1, img_width - 1),
  718. random.randint(0, distorted_half_height))
  719. bottom_right = (random.randint(img_width - distorted_half_width - 1, img_width - 1),
  720. random.randint(img_height - distorted_half_height - 1, img_height - 1))
  721. bottom_left = (random.randint(0, distorted_half_width),
  722. random.randint(img_height - distorted_half_height - 1, img_height - 1))
  723. start_points = [(0, 0), (img_width - 1, 0), (img_width - 1, img_height - 1), (0, img_height - 1)]
  724. end_points = [top_left, top_right, bottom_right, bottom_left]
  725. return start_points, end_points
  726. def perspective(img, start_points, end_points, interpolation=Inter.BICUBIC):
  727. """
  728. Apply perspective transformation to the input PIL image.
  729. Args:
  730. img (PIL image): PIL image to be applied perspective transformation.
  731. start_points (list): List of [top_left, top_right, bottom_right, bottom_left] of the original image.
  732. end_points: List of [top_left, top_right, bottom_right, bottom_left] of the transformed image.
  733. interpolation (interpolation mode): Image interpolation mode, Default is Inter.BICUBIC = 3.
  734. Returns:
  735. img (PIL image), Image after being perspectively transformed.
  736. """
  737. def _input_to_coeffs(original_points, transformed_points):
  738. # Get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
  739. # According to "Using Projective Geometry to Correct a Camera" from AMS.
  740. # http://www.ams.org/publicoutreach/feature-column/fc-2013-03
  741. # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Geometry.c#L377
  742. matrix = []
  743. for pt1, pt2 in zip(transformed_points, original_points):
  744. matrix.append([pt1[0], pt1[1], 1, 0, 0, 0, -pt2[0] * pt1[0], -pt2[0] * pt1[1]])
  745. matrix.append([0, 0, 0, pt1[0], pt1[1], 1, -pt2[1] * pt1[0], -pt2[1] * pt1[1]])
  746. matrix_a = np.array(matrix, dtype=np.float)
  747. matrix_b = np.array(original_points, dtype=np.float).reshape(8)
  748. res = np.linalg.lstsq(matrix_a, matrix_b, rcond=None)[0]
  749. return res.tolist()
  750. if not is_pil(img):
  751. raise TypeError(augment_error_message.format(type(img)))
  752. coeffs = _input_to_coeffs(start_points, end_points)
  753. return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation)
def get_erase_params(np_img, scale, ratio, value, bounded, max_attempts):
    """Helper function to get parameters for RandomErasing/ Cutout.

    Samples an erase rectangle (i, j, erase_h, erase_w) for a (C, H, W) image
    and the value to fill it with.  Retries up to max_attempts times; if no
    rectangle smaller than the image is found, returns the whole image area
    with the image itself as the "erase value" (an effective no-op).
    """
    if not is_numpy(np_img):
        raise TypeError('img should be NumPy array. Got {}'.format(type(np_img)))

    # np_img is laid out channels-first: (C, H, W)
    image_c, image_h, image_w = np_img.shape
    area = image_h * image_w

    for _ in range(max_attempts):
        # sample a target area fraction and aspect ratio, then derive box dims;
        # note the RNG draw order here is part of the reproducible behavior
        erase_area = random.uniform(scale[0], scale[1]) * area
        aspect_ratio = random.uniform(ratio[0], ratio[1])
        erase_w = int(round(math.sqrt(erase_area * aspect_ratio)))
        erase_h = int(round(erase_w / aspect_ratio))
        erase_shape = (image_c, erase_h, erase_w)

        # only accept boxes strictly smaller than the image
        if erase_h < image_h and erase_w < image_w:
            if bounded:
                # RandomErasing style: the box must lie fully inside the image
                i = random.randint(0, image_h - erase_h)
                j = random.randint(0, image_w - erase_w)
            else:
                # Cutout style: pick a center anywhere, then clip the box to the
                # image, which may shrink erase_h / erase_w at the borders
                def clip(x, lower, upper):
                    return max(lower, min(x, upper))

                x = random.randint(0, image_w)
                y = random.randint(0, image_h)
                x1 = clip(x - erase_w // 2, 0, image_w)
                x2 = clip(x + erase_w // 2, 0, image_w)
                y1 = clip(y - erase_h // 2, 0, image_h)
                y2 = clip(y + erase_h // 2, 0, image_h)
                i, j, erase_h, erase_w = y1, x1, y2 - y1, x2 - x1

            # resolve the fill value: scalar, random noise (str/bytes flag),
            # or a per-channel 3-sequence broadcast over the box
            if isinstance(value, numbers.Number):
                erase_value = value
            elif isinstance(value, (str, bytes)):
                erase_value = np.random.normal(loc=0.0, scale=1.0, size=erase_shape)
            elif isinstance(value, (tuple, list)) and len(value) == 3:
                value = np.array(value)
                erase_value = np.multiply(np.ones(erase_shape), value[:, None, None])
            else:
                raise ValueError("The value for erasing should be either a single value, or a string "
                                 "'random', or a sequence of 3 elements for RGB respectively.")

            return i, j, erase_h, erase_w, erase_value

    # exceeding max_attempts, return original image
    return 0, 0, image_h, image_w, np_img
  794. def erase(np_img, i, j, height, width, erase_value, inplace=False):
  795. """
  796. Erase the pixels, within a selected rectangle region, to the given value. Applied on the input NumPy image array.
  797. Args:
  798. np_img (numpy.ndarray): NumPy image array of shape (C, H, W) to be erased.
  799. i (int): The height component of the top left corner (height, width).
  800. j (int): The width component of the top left corner (height, width).
  801. height (int): Height of the erased region.
  802. width (int): Width of the erased region.
  803. erase_value: Erase value return from helper function get_erase_params().
  804. inplace (bool, optional): Apply this transform inplace. Default is False.
  805. Returns:
  806. np_img (numpy.ndarray), Erased NumPy image array.
  807. """
  808. if not is_numpy(np_img):
  809. raise TypeError('img should be NumPy array. Got {}'.format(type(np_img)))
  810. if not inplace:
  811. np_img = np_img.copy()
  812. # (i, j) here are the coordinates of axes (height, width) as in CHW
  813. np_img[:, i:i + height, j:j + width] = erase_value
  814. return np_img
  815. def linear_transform(np_img, transformation_matrix, mean_vector):
  816. """
  817. Apply linear transformation to the input NumPy image array, given a square transformation matrix and a mean_vector.
  818. The transformation first flattens the input array and subtract mean_vector from it, then computes the
  819. dot product with the transformation matrix, and reshapes it back to its original shape.
  820. Args:
  821. np_img (numpy.ndarray): NumPy image array of shape (C, H, W) to be linear transformed.
  822. transformation_matrix (numpy.ndarray): a square transformation matrix of shape (D, D), D = C x H x W.
  823. mean_vector (numpy.ndarray): a NumPy ndarray of shape (D,) where D = C x H x W.
  824. Returns:
  825. np_img (numpy.ndarray), Linear transformed image.
  826. """
  827. if not is_numpy(np_img):
  828. raise TypeError('img should be NumPy array. Got {}'.format(type(np_img)))
  829. if transformation_matrix.shape[0] != transformation_matrix.shape[1]:
  830. raise ValueError("transformation_matrix should be a square matrix. "
  831. "Got shape {} instead".format(transformation_matrix.shape))
  832. if np.prod(np_img.shape) != transformation_matrix.shape[0]:
  833. raise ValueError("transformation_matrix shape {0} not compatible with "
  834. "Numpy image shape {1}.".format(transformation_matrix.shape, np_img.shape))
  835. if mean_vector.shape[0] != transformation_matrix.shape[0]:
  836. raise ValueError("mean_vector length {0} should match either one dimension of the square "
  837. "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))
  838. zero_centered_img = np_img.reshape(1, -1) - mean_vector
  839. transformed_img = np.dot(zero_centered_img, transformation_matrix).reshape(np_img.shape)
  840. return transformed_img
  841. def random_affine(img, angle, translations, scale, shear, resample, fill_value=0):
  842. """
  843. Applies a random Affine transformation on the input PIL image.
  844. Args:
  845. img (PIL image): Image to be applied affine transformation.
  846. angle (Union[int, float]): Rotation angle in degrees, clockwise.
  847. translations (sequence): Translations in horizontal and vertical axis.
  848. scale (float): Scale parameter, a single number.
  849. shear (Union[float, sequence]): Shear amount parallel to X axis and Y axis.
  850. resample (Union[Inter.NEAREST, Inter.BILINEAR, Inter.BICUBIC], optional): An optional resampling filter.
  851. fill_value (Union[tuple int], optional): Optional fill_value to fill the area outside the transform
  852. in the output image. Used only in Pillow versions > 5.0.0.
  853. If None, no filling is performed.
  854. Returns:
  855. img (PIL image), Randomly affine transformed image.
  856. """
  857. if not is_pil(img):
  858. raise ValueError("Input image should be a Pillow image.")
  859. # rotation
  860. angle = random.uniform(angle[0], angle[1])
  861. # translation
  862. if translations is not None:
  863. max_dx = translations[0] * img.size[0]
  864. max_dy = translations[1] * img.size[1]
  865. translations = (np.round(random.uniform(-max_dx, max_dx)),
  866. np.round(random.uniform(-max_dy, max_dy)))
  867. else:
  868. translations = (0, 0)
  869. # scale
  870. if scale is not None:
  871. scale = random.uniform(scale[0], scale[1])
  872. else:
  873. scale = 1.0
  874. # shear
  875. if shear is not None:
  876. if len(shear) == 2:
  877. shear = [random.uniform(shear[0], shear[1]), 0.]
  878. elif len(shear) == 4:
  879. shear = [random.uniform(shear[0], shear[1]),
  880. random.uniform(shear[2], shear[3])]
  881. else:
  882. shear = 0.0
  883. output_size = img.size
  884. center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
  885. angle = math.radians(angle)
  886. if isinstance(shear, (tuple, list)) and len(shear) == 2:
  887. shear = [math.radians(s) for s in shear]
  888. elif isinstance(shear, numbers.Number):
  889. shear = math.radians(shear)
  890. shear = [shear, 0]
  891. else:
  892. raise ValueError(
  893. "Shear should be a single value or a tuple/list containing " +
  894. "two values. Got {}".format(shear))
  895. scale = 1.0 / scale
  896. # Inverted rotation matrix with scale and shear
  897. d = math.cos(angle + shear[0]) * math.cos(angle + shear[1]) + \
  898. math.sin(angle + shear[0]) * math.sin(angle + shear[1])
  899. matrix = [
  900. math.cos(angle + shear[0]), math.sin(angle + shear[0]), 0,
  901. -math.sin(angle + shear[1]), math.cos(angle + shear[1]), 0
  902. ]
  903. matrix = [scale / d * m for m in matrix]
  904. # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
  905. matrix[2] += matrix[0] * (-center[0] - translations[0]) + matrix[1] * (-center[1] - translations[1])
  906. matrix[5] += matrix[3] * (-center[0] - translations[0]) + matrix[4] * (-center[1] - translations[1])
  907. # Apply center translation: C * RSS^-1 * C^-1 * T^-1
  908. matrix[2] += center[0]
  909. matrix[5] += center[1]
  910. if __version__ >= '5':
  911. kwargs = {"fillcolor": fill_value}
  912. else:
  913. kwargs = {}
  914. return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)
  915. def mix_up_single(batch_size, img, label, alpha=0.2):
  916. """
  917. Apply mix up transformation to image and label in single batch internal, One hot encoding should done before this.
  918. Args:
  919. batch_size (int): the batch size of dataset.
  920. img (numpy.ndarray): NumPy image to be applied mix up transformation.
  921. label (numpy.ndarray): NumPy label to be applied mix up transformation.
  922. alpha (float): the mix up rate.
  923. Returns:
  924. mix_img (numpy.ndarray): NumPy image after being applied mix up transformation.
  925. mix_label (numpy.ndarray): NumPy label after being applied mix up transformation.
  926. """
  927. def cir_shift(data):
  928. index = list(range(1, batch_size)) + [0]
  929. data = data[index, ...]
  930. return data
  931. lam = np.random.beta(alpha, alpha, batch_size)
  932. lam_img = lam.reshape((batch_size, 1, 1, 1))
  933. mix_img = lam_img * img + (1 - lam_img) * cir_shift(img)
  934. lam_label = lam.reshape((batch_size, 1))
  935. mix_label = lam_label * label + (1 - lam_label) * cir_shift(label)
  936. return mix_img, mix_label
  937. def mix_up_muti(tmp, batch_size, img, label, alpha=0.2):
  938. """
  939. Apply mix up transformation to image and label in continuous batch, one hot encoding should done before this.
  940. Args:
  941. tmp (class object): mainly for saving the tmp parameter.
  942. batch_size (int): the batch size of dataset.
  943. img (numpy.ndarray): NumPy image to be applied mix up transformation.
  944. label (numpy.ndarray): NumPy label to be applied mix up transformation.
  945. alpha (float): refer to the mix up rate.
  946. Returns:
  947. mix_img (numpy.ndarray): NumPy image after being applied mix up transformation.
  948. mix_label (numpy.ndarray): NumPy label after being applied mix up transformation.
  949. """
  950. lam = np.random.beta(alpha, alpha, batch_size)
  951. if tmp.is_first:
  952. lam = np.ones(batch_size)
  953. tmp.is_first = False
  954. lam_img = lam.reshape((batch_size, 1, 1, 1))
  955. mix_img = lam_img * img + (1 - lam_img) * tmp.image
  956. lam_label = lam.reshape(batch_size, 1)
  957. mix_label = lam_label * label + (1 - lam_label) * tmp.label
  958. tmp.image = mix_img
  959. tmp.label = mix_label
  960. return mix_img, mix_label
  961. def rgb_to_hsv(np_rgb_img, is_hwc):
  962. """
  963. Convert RGB img to HSV img.
  964. Args:
  965. np_rgb_img (numpy.ndarray): NumPy RGB image array of shape (H, W, C) or (C, H, W) to be converted.
  966. is_hwc (Bool): If True, the shape of np_hsv_img is (H, W, C), otherwise must be (C, H, W).
  967. Returns:
  968. np_hsv_img (numpy.ndarray), NumPy HSV image with same type of np_rgb_img.
  969. """
  970. if is_hwc:
  971. r, g, b = np_rgb_img[:, :, 0], np_rgb_img[:, :, 1], np_rgb_img[:, :, 2]
  972. else:
  973. r, g, b = np_rgb_img[0, :, :], np_rgb_img[1, :, :], np_rgb_img[2, :, :]
  974. to_hsv = np.vectorize(colorsys.rgb_to_hsv)
  975. h, s, v = to_hsv(r, g, b)
  976. if is_hwc:
  977. axis = 2
  978. else:
  979. axis = 0
  980. np_hsv_img = np.stack((h, s, v), axis=axis)
  981. return np_hsv_img
  982. def rgb_to_hsvs(np_rgb_imgs, is_hwc):
  983. """
  984. Convert RGB imgs to HSV imgs.
  985. Args:
  986. np_rgb_imgs (numpy.ndarray): NumPy RGB images array of shape (H, W, C) or (N, H, W, C),
  987. or (C, H, W) or (N, C, H, W) to be converted.
  988. is_hwc (Bool): If True, the shape of np_rgb_imgs is (H, W, C) or (N, H, W, C);
  989. If False, the shape of np_rgb_imgs is (C, H, W) or (N, C, H, W).
  990. Returns:
  991. np_hsv_imgs (numpy.ndarray), NumPy HSV images with same type of np_rgb_imgs.
  992. """
  993. if not is_numpy(np_rgb_imgs):
  994. raise TypeError('img should be NumPy image. Got {}'.format(type(np_rgb_imgs)))
  995. shape_size = len(np_rgb_imgs.shape)
  996. if not shape_size in (3, 4):
  997. raise TypeError('img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). \
  998. Got {}'.format(np_rgb_imgs.shape))
  999. if shape_size == 3:
  1000. batch_size = 0
  1001. if is_hwc:
  1002. num_channels = np_rgb_imgs.shape[2]
  1003. else:
  1004. num_channels = np_rgb_imgs.shape[0]
  1005. else:
  1006. batch_size = np_rgb_imgs.shape[0]
  1007. if is_hwc:
  1008. num_channels = np_rgb_imgs.shape[3]
  1009. else:
  1010. num_channels = np_rgb_imgs.shape[1]
  1011. if num_channels != 3:
  1012. raise TypeError('img should be 3 channels RGB img. Got {} channels'.format(num_channels))
  1013. if batch_size == 0:
  1014. return rgb_to_hsv(np_rgb_imgs, is_hwc)
  1015. return np.array([rgb_to_hsv(img, is_hwc) for img in np_rgb_imgs])
  1016. def hsv_to_rgb(np_hsv_img, is_hwc):
  1017. """
  1018. Convert HSV img to RGB img.
  1019. Args:
  1020. np_hsv_img (numpy.ndarray): NumPy HSV image array of shape (H, W, C) or (C, H, W) to be converted.
  1021. is_hwc (Bool): If True, the shape of np_hsv_img is (H, W, C), otherwise must be (C, H, W).
  1022. Returns:
  1023. np_rgb_img (numpy.ndarray), NumPy HSV image with same shape of np_hsv_img.
  1024. """
  1025. if is_hwc:
  1026. h, s, v = np_hsv_img[:, :, 0], np_hsv_img[:, :, 1], np_hsv_img[:, :, 2]
  1027. else:
  1028. h, s, v = np_hsv_img[0, :, :], np_hsv_img[1, :, :], np_hsv_img[2, :, :]
  1029. to_rgb = np.vectorize(colorsys.hsv_to_rgb)
  1030. r, g, b = to_rgb(h, s, v)
  1031. if is_hwc:
  1032. axis = 2
  1033. else:
  1034. axis = 0
  1035. np_rgb_img = np.stack((r, g, b), axis=axis)
  1036. return np_rgb_img
  1037. def hsv_to_rgbs(np_hsv_imgs, is_hwc):
  1038. """
  1039. Convert HSV imgs to RGB imgs.
  1040. Args:
  1041. np_hsv_imgs (numpy.ndarray): NumPy HSV images array of shape (H, W, C) or (N, H, W, C),
  1042. or (C, H, W) or (N, C, H, W) to be converted.
  1043. is_hwc (Bool): If True, the shape of np_hsv_imgs is (H, W, C) or (N, H, W, C);
  1044. If False, the shape of np_hsv_imgs is (C, H, W) or (N, C, H, W).
  1045. Returns:
  1046. np_rgb_imgs (numpy.ndarray), NumPy RGB images with same type of np_hsv_imgs.
  1047. """
  1048. if not is_numpy(np_hsv_imgs):
  1049. raise TypeError('img should be NumPy image. Got {}'.format(type(np_hsv_imgs)))
  1050. shape_size = len(np_hsv_imgs.shape)
  1051. if not shape_size in (3, 4):
  1052. raise TypeError('img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). \
  1053. Got {}'.format(np_hsv_imgs.shape))
  1054. if shape_size == 3:
  1055. batch_size = 0
  1056. if is_hwc:
  1057. num_channels = np_hsv_imgs.shape[2]
  1058. else:
  1059. num_channels = np_hsv_imgs.shape[0]
  1060. else:
  1061. batch_size = np_hsv_imgs.shape[0]
  1062. if is_hwc:
  1063. num_channels = np_hsv_imgs.shape[3]
  1064. else:
  1065. num_channels = np_hsv_imgs.shape[1]
  1066. if num_channels != 3:
  1067. raise TypeError('img should be 3 channels RGB img. Got {} channels'.format(num_channels))
  1068. if batch_size == 0:
  1069. return hsv_to_rgb(np_hsv_imgs, is_hwc)
  1070. return np.array([hsv_to_rgb(img, is_hwc) for img in np_hsv_imgs])
  1071. def random_color(img, degrees):
  1072. """
  1073. Adjust the color of the input PIL image by a random degree.
  1074. Args:
  1075. img (PIL image): Image to be color adjusted.
  1076. degrees (sequence): Range of random color adjustment degrees.
  1077. It should be in (min, max) format (default=(0.1,1.9)).
  1078. Returns:
  1079. img (PIL image), Color adjusted image.
  1080. """
  1081. if not is_pil(img):
  1082. raise TypeError('img should be PIL image. Got {}'.format(type(img)))
  1083. v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
  1084. return ImageEnhance.Color(img).enhance(v)
  1085. def random_sharpness(img, degrees):
  1086. """
  1087. Adjust the sharpness of the input PIL image by a random degree.
  1088. Args:
  1089. img (PIL image): Image to be sharpness adjusted.
  1090. degrees (sequence): Range of random sharpness adjustment degrees.
  1091. It should be in (min, max) format (default=(0.1,1.9)).
  1092. Returns:
  1093. img (PIL image), Sharpness adjusted image.
  1094. """
  1095. if not is_pil(img):
  1096. raise TypeError('img should be PIL image. Got {}'.format(type(img)))
  1097. v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
  1098. return ImageEnhance.Sharpness(img).enhance(v)
  1099. def auto_contrast(img, cutoff, ignore):
  1100. """
  1101. Automatically maximize the contrast of the input PIL image.
  1102. Args:
  1103. img (PIL image): Image to be augmented with AutoContrast.
  1104. cutoff (float, optional): Percent of pixels to cut off from the histogram (default=0.0).
  1105. ignore (Union[int, sequence], optional): Pixel values to ignore (default=None).
  1106. Returns:
  1107. img (PIL image), Augmented image.
  1108. """
  1109. if not is_pil(img):
  1110. raise TypeError('img should be PIL image. Got {}'.format(type(img)))
  1111. return ImageOps.autocontrast(img, cutoff, ignore)
  1112. def invert_color(img):
  1113. """
  1114. Invert colors of input PIL image.
  1115. Args:
  1116. img (PIL image): Image to be color inverted.
  1117. Returns:
  1118. img (PIL image), Color inverted image.
  1119. """
  1120. if not is_pil(img):
  1121. raise TypeError('img should be PIL image. Got {}'.format(type(img)))
  1122. return ImageOps.invert(img)
  1123. def equalize(img):
  1124. """
  1125. Equalize the histogram of input PIL image.
  1126. Args:
  1127. img (PIL image): Image to be equalized
  1128. Returns:
  1129. img (PIL image), Equalized image.
  1130. """
  1131. if not is_pil(img):
  1132. raise TypeError('img should be PIL image. Got {}'.format(type(img)))
  1133. return ImageOps.equalize(img)
  1134. def uniform_augment(img, transforms, num_ops):
  1135. """
  1136. Uniformly select and apply a number of transforms sequentially from
  1137. a list of transforms. Randomly assigns a probability to each transform for
  1138. each image to decide whether apply it or not.
  1139. All the transforms in transform list must have the same input/output data type.
  1140. Args:
  1141. img: Image to be applied transformation.
  1142. transforms (list): List of transformations to be chosen from to apply.
  1143. num_ops (int): number of transforms to sequentially aaply.
  1144. Returns:
  1145. img, Transformed image.
  1146. """
  1147. op_idx = np.random.choice(len(transforms), size=num_ops, replace=False)
  1148. for idx in op_idx:
  1149. AugmentOp = transforms[idx]
  1150. pr = random.random()
  1151. if random.random() < pr:
  1152. img = AugmentOp(img.copy())
  1153. return img