
transform.py 35 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc
import math
from typing import Sequence, Tuple

import cv2
import numpy as np

from megengine.data.transform import Transform
from megengine.data.transform.vision import functional as F
__all__ = [
    "VisionTransform",
    "ToMode",
    "Compose",
    "TorchTransformCompose",
    "Pad",
    "Resize",
    "ShortestEdgeResize",
    "RandomResize",
    "RandomCrop",
    "RandomResizedCrop",
    "CenterCrop",
    "RandomHorizontalFlip",
    "RandomVerticalFlip",
    "Normalize",
    "GaussianNoise",
    "BrightnessTransform",
    "SaturationTransform",
    "ContrastTransform",
    "HueTransform",
    "ColorJitter",
    "Lighting",
]

class VisionTransform(Transform):
    r"""
    Base class of all transforms used in computer vision.
    Calling logic: apply_batch() -> apply() -> _apply_image() and other _apply_*()
    methods. If you want to implement a self-defined transform method for images,
    override the _apply_image method in a subclass.

    :param order: Input type order. Input is a tuple containing different structures;
        order is used to specify the order of the structures. For example, if your
        input is (image, boxes), then the order should be ("image", "boxes").
        Currently available strings and data types are described below:

        "image":
            input image, with shape of (H, W, C)
        "coords":
            coordinates, with shape of (N, 2)
        "boxes":
            bounding boxes, with shape of (N, 4), "xyxy" format;
            the 1st "xy" represents the top left point of a box,
            the 2nd "xy" represents the right bottom point.
        "mask":
            map used for segmentation, with shape of (H, W, 1)
        "keypoints":
            keypoints with shape of (N, K, 3), N for number of instances
            and K for number of keypoints in one instance. The first two entries
            of the last axis are the coordinates of a keypoint and the 3rd entry
            is its label.
        "polygons":
            a sequence of numpy arrays, one per instance.
            Each numpy array represents the polygon coordinates of one instance.
        "category":
            categories for some data types. For example, "image_category"
            means the category of the input image and "boxes_category" means
            the categories of the bounding boxes.
        "info":
            information for images such as image shape and image path.

    You can also customize your own data types, as long as you implement the
    corresponding _apply_*() methods; otherwise ``NotImplementedError`` will be raised.
    """

    def __init__(self, order=None):
        super().__init__()
        if order is None:
            order = ("image",)
        elif not isinstance(order, collections.abc.Sequence):
            raise ValueError(
                "order should be a sequence, but got order={}".format(order)
            )
        for k in order:
            if k in ("batch",):
                raise ValueError("{} is invalid data type".format(k))
            elif k.endswith("category") or k.endswith("info"):
                # when the key is *category or info, we should do nothing
                # if the corresponding apply methods are not implemented.
                continue
            elif self._get_apply(k) is None:
                raise NotImplementedError("{} is unsupported data type".format(k))
        self.order = order

    def apply_batch(self, inputs: Sequence[Tuple]):
        r"""Apply transform on batch input data."""
        return tuple(self.apply(input) for input in inputs)

    def apply(self, input: Tuple):
        r"""Apply transform on single input data."""
        if not isinstance(input, tuple):
            input = (input,)
        output = []
        for i in range(min(len(input), len(self.order))):
            apply_func = self._get_apply(self.order[i])
            if apply_func is None:
                output.append(input[i])
            else:
                output.append(apply_func(input[i]))
        if len(input) > len(self.order):
            output.extend(input[len(self.order) :])
        if len(output) == 1:
            output = output[0]
        else:
            output = tuple(output)
        return output

    def _get_apply(self, key):
        return getattr(self, "_apply_{}".format(key), None)

    def _get_image(self, input: Tuple):
        if not isinstance(input, tuple):
            input = (input,)
        return input[self.order.index("image")]

    def _apply_image(self, image):
        raise NotImplementedError

    def _apply_coords(self, coords):
        raise NotImplementedError

    def _apply_boxes(self, boxes):
        idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
        coords = np.asarray(boxes).reshape(-1, 4)[:, idxs].reshape(-1, 2)
        coords = self._apply_coords(coords).reshape((-1, 4, 2))
        minxy = coords.min(axis=1)
        maxxy = coords.max(axis=1)
        trans_boxes = np.concatenate((minxy, maxxy), axis=1)
        return trans_boxes

    def _apply_mask(self, mask):
        raise NotImplementedError

    def _apply_keypoints(self, keypoints):
        coords, visibility = keypoints[..., :2], keypoints[..., 2:]
        trans_coords = [self._apply_coords(p) for p in coords]
        return np.concatenate((trans_coords, visibility), axis=-1)

    def _apply_polygons(self, polygons):
        return [[self._apply_coords(p) for p in instance] for instance in polygons]
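
# Illustrative sketch (not part of the original file): a custom transform only
# needs to override the _apply_*() methods it cares about. The class below is
# hypothetical.
#
#   class SwapChannels(VisionTransform):
#       """Reverse the channel order of an HWC image; geometry is unchanged."""
#
#       def _apply_image(self, image):
#           return image[:, :, ::-1]
#
#       def _apply_coords(self, coords):
#           return coords
#
#   # SwapChannels(order=("image", "boxes")).apply((img, boxes)) flips the
#   # channels of img and rebuilds boxes through the default _apply_boxes().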

class ToMode(VisionTransform):
    r"""Change input data to a target mode.
    For example, most transforms use HWC-mode images,
    while neural networks usually take CHW-mode input tensors.

    :param mode: Output mode of input. Use "CHW" mode by default.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, mode="CHW", *, order=None):
        super().__init__(order)
        assert mode in ["CHW"], "unsupported mode: {}".format(mode)
        self.mode = mode

    def _apply_image(self, image):
        if self.mode == "CHW":
            return np.ascontiguousarray(np.rollaxis(image, 2))
        return image

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        if self.mode == "CHW":
            return np.ascontiguousarray(np.rollaxis(mask, 2))
        return mask
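
# Usage sketch (illustrative): HWC -> CHW on a dummy image.
#
#   import numpy as np
#   img = np.zeros((224, 224, 3), dtype=np.uint8)   # H, W, C
#   chw = ToMode("CHW").apply(img)
#   assert chw.shape == (3, 224, 224)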

class Compose(VisionTransform):
    r"""
    Composes several transforms together.

    :param transforms: List of ``VisionTransform`` to compose.
    :param batch_compose: Whether to apply transforms along the whole batch.
        If True, transforms are applied batch-wise in their original order and
        shuffling is not supported; otherwise, transforms are applied per sample
        and ``shuffle_indices`` takes effect.
    :param shuffle_indices: Indices used for random shuffle, starting at 1.
        For example, if shuffle_indices is [(1, 3), (2, 4)], then the 1st and 3rd
        transforms will be randomly shuffled, and the 2nd and 4th transforms will
        also be shuffled.
    :param order: The same with ``VisionTransform``.

    Example:

    .. testcode::

        from megengine.data.transform import RandomHorizontalFlip, RandomVerticalFlip, CenterCrop, ToMode, Compose

        transform_func = Compose([
            RandomHorizontalFlip(),
            RandomVerticalFlip(),
            CenterCrop(100),
            ToMode("CHW"),
            ],
            shuffle_indices=[(1, 2, 3)]
        )
    """

    def __init__(
        self, transforms=[], batch_compose=False, shuffle_indices=None, *, order=None
    ):
        super().__init__(order)
        self.transforms = transforms
        self._set_order()
        if batch_compose and shuffle_indices is not None:
            raise ValueError(
                "Do not support shuffle when apply transforms along the whole batch"
            )
        self.batch_compose = batch_compose
        if shuffle_indices is not None:
            shuffle_indices = [tuple(x - 1 for x in idx) for idx in shuffle_indices]
        self.shuffle_indices = shuffle_indices

    def _set_order(self):
        for t in self.transforms:
            t.order = self.order
            if isinstance(t, Compose):
                t._set_order()

    def apply_batch(self, inputs: Sequence[Tuple]):
        if self.batch_compose:
            for t in self.transforms:
                inputs = t.apply_batch(inputs)
            return inputs
        else:
            return super().apply_batch(inputs)

    def apply(self, input: Tuple):
        for t in self._shuffle():
            input = t.apply(input)
        return input

    def _shuffle(self):
        if self.shuffle_indices is not None:
            source_idx = list(range(len(self.transforms)))
            for idx in self.shuffle_indices:
                shuffled = np.random.permutation(idx).tolist()
                for src, dst in zip(idx, shuffled):
                    source_idx[src] = dst
            return [self.transforms[i] for i in source_idx]
        else:
            return self.transforms
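
# Usage sketch (illustrative): composing transforms over an (image, boxes) pair.
#
#   import numpy as np
#   img = np.zeros((128, 128, 3), dtype=np.uint8)
#   boxes = np.array([[10, 10, 50, 50]], dtype=np.float32)   # xyxy format
#   func = Compose([RandomHorizontalFlip(), CenterCrop(100)],
#                  order=("image", "boxes"))
#   out_img, out_boxes = func.apply((img, boxes))
#   # out_img has shape (100, 100, 3); out_boxes is flipped/shifted to match.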

class TorchTransformCompose(VisionTransform):
    r"""
    Compose class used for transforms in torchvision. Only PIL images are
    supported; torchvision transforms that operate on tensors, such as
    torchvision's Normalize and ToTensor, are not supported.

    :param transforms: The same with ``Compose``.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, transforms, *, order=None):
        super().__init__(order)
        self.transforms = transforms

    def _apply_image(self, image):
        from PIL import Image

        try:
            import accimage
        except ImportError:
            accimage = None

        if image.shape[0] == 3:  # CHW
            image = np.ascontiguousarray(image[[2, 1, 0]])
        elif image.shape[2] == 3:  # HWC
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(image.astype(np.uint8))

        for t in self.transforms:
            image = t(image)

        if isinstance(image, Image.Image) or (
            accimage is not None and isinstance(image, accimage.Image)
        ):
            image = np.array(image, dtype=np.uint8)
        if image.shape[0] == 3:  # CHW
            image = np.ascontiguousarray(image[[2, 1, 0]])
        elif image.shape[2] == 3:  # HWC
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        return image

class Pad(VisionTransform):
    r"""Pad the input data.

    :param size: Padding size of the input image; it can be an integer or a sequence.
        If it is an integer, the input image will be padded in all four directions.
        If it is a sequence containing two integers, the bottom and right sides
        of the image will be padded.
        If it is a sequence containing four integers, the top, bottom, left and
        right sides of the image will be padded with the given sizes.
    :param value: Padding value of the image; can be a sequence of int or float.
        If it is a float, the dtype of the image will also be cast to float32.
    :param mask_value: Padding value of the segmentation map.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, size=0, value=0, mask_value=0, *, order=None):
        super().__init__(order)
        if isinstance(size, int):
            size = (size, size, size, size)
        elif isinstance(size, collections.abc.Sequence) and len(size) == 2:
            size = (0, size[0], 0, size[1])
        elif not (isinstance(size, collections.abc.Sequence) and len(size) == 4):
            raise ValueError(
                "size should be a list/tuple which contains "
                "(top, bottom, left, right) four pad sizes."
            )
        self.size = size
        self.value = value
        if not isinstance(mask_value, int):
            raise ValueError(
                "mask_value should be a positive integer, "
                "but got mask_value={}".format(mask_value)
            )
        self.mask_value = mask_value

    def _apply_image(self, image):
        return F.pad(image, self.size, self.value)

    def _apply_coords(self, coords):
        coords[:, 0] += self.size[2]
        coords[:, 1] += self.size[0]
        return coords

    def _apply_mask(self, mask):
        return F.pad(mask, self.size, self.mask_value)
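
# Usage sketch (illustrative): pad 10 pixels on every side of a dummy image.
#
#   import numpy as np
#   img = np.zeros((32, 32, 3), dtype=np.uint8)
#   padded = Pad(size=10, value=0).apply(img)
#   assert padded.shape == (52, 52, 3)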

class Resize(VisionTransform):
    r"""Resize the input data.

    :param output_size: Target size of the image, with (height, width) shape.
        If an integer is given, the shorter edge is resized to it and the aspect
        ratio is kept.
    :param interpolation: Interpolation method. All methods are listed below:

        * cv2.INTER_NEAREST – a nearest-neighbor interpolation.
        * cv2.INTER_LINEAR – a bilinear interpolation (used by default).
        * cv2.INTER_AREA – resampling using pixel area relation.
        * cv2.INTER_CUBIC – a bicubic interpolation over a 4×4 pixel neighborhood.
        * cv2.INTER_LANCZOS4 – a Lanczos interpolation over an 8×8 pixel neighborhood.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, output_size, interpolation=cv2.INTER_LINEAR, *, order=None):
        super().__init__(order)
        self.output_size = output_size
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._shape_info = self._get_shape(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return image
        return F.resize(image, (th, tw), self.interpolation)

    def _apply_coords(self, coords):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return coords
        coords[:, 0] = coords[:, 0] * (tw / w)
        coords[:, 1] = coords[:, 1] * (th / h)
        return coords

    def _apply_mask(self, mask):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return mask
        return F.resize(mask, (th, tw), cv2.INTER_NEAREST)

    def _get_shape(self, image):
        h, w, _ = image.shape
        if isinstance(self.output_size, int):
            if min(h, w) == self.output_size:
                return h, w, h, w
            if h < w:
                th = self.output_size
                tw = int(self.output_size * w / h)
            else:
                tw = self.output_size
                th = int(self.output_size * h / w)
            return h, w, th, tw
        else:
            return (h, w, *self.output_size)
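
# Usage sketch (illustrative): resize to a fixed (height, width).
#
#   import numpy as np
#   img = np.zeros((480, 640, 3), dtype=np.uint8)
#   out = Resize((224, 224)).apply(img)
#   assert out.shape == (224, 224, 3)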

class ShortestEdgeResize(VisionTransform):
    r"""Resize the input data so that its shortest edge matches a sampled target size.

    The target size is drawn from ``min_size``: uniformly from
    [min_size[0], min_size[1]] when ``sample_style`` is "range", or picked from
    the given candidates when ``sample_style`` is "choice". If the resulting
    longest edge would exceed ``max_size``, the image is scaled down so that it fits.

    :param min_size: Range (or candidate list) of the shortest-edge size.
    :param max_size: Upper bound of the longest edge after resizing.
    :param sample_style: Either "range" or "choice".
    :param interpolation: Interpolation method; the same with ``Resize``.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(
        self,
        min_size,
        max_size,
        sample_style="range",
        interpolation=cv2.INTER_LINEAR,
        *,
        order=None
    ):
        super().__init__(order)
        if sample_style not in ("range", "choice"):
            raise NotImplementedError(
                "{} is unsupported sample style".format(sample_style)
            )
        self.sample_style = sample_style
        if isinstance(min_size, int):
            min_size = (min_size, min_size)
        self.min_size = min_size
        self.max_size = max_size
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._shape_info = self._get_shape(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return image
        return F.resize(image, (th, tw), self.interpolation)

    def _apply_coords(self, coords):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return coords
        coords[:, 0] = coords[:, 0] * (tw / w)
        coords[:, 1] = coords[:, 1] * (th / h)
        return coords

    def _apply_mask(self, mask):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return mask
        return F.resize(mask, (th, tw), cv2.INTER_NEAREST)

    def _get_shape(self, image):
        h, w, _ = image.shape
        if self.sample_style == "range":
            size = np.random.randint(self.min_size[0], self.min_size[1] + 1)
        else:
            size = np.random.choice(self.min_size)
        scale = size / min(h, w)
        if h < w:
            th, tw = size, scale * w
        else:
            th, tw = scale * h, size
        if max(th, tw) > self.max_size:
            scale = self.max_size / max(th, tw)
            th = th * scale
            tw = tw * scale
        th = int(round(th))
        tw = int(round(tw))
        return h, w, th, tw

class RandomResize(VisionTransform):
    r"""Resize the input data randomly.

    :param scale_range: (min, max) range from which the scale factor is uniformly sampled.
    :param interpolation: Interpolation method; the same with ``Resize``.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, scale_range, interpolation=cv2.INTER_LINEAR, *, order=None):
        super().__init__(order)
        self.scale_range = scale_range
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._shape_info = self._get_shape(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return image
        return F.resize(image, (th, tw), self.interpolation)

    def _apply_coords(self, coords):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return coords
        coords[:, 0] = coords[:, 0] * (tw / w)
        coords[:, 1] = coords[:, 1] * (th / h)
        return coords

    def _apply_mask(self, mask):
        h, w, th, tw = self._shape_info
        if h == th and w == tw:
            return mask
        return F.resize(mask, (th, tw), cv2.INTER_NEAREST)

    def _get_shape(self, image):
        h, w, _ = image.shape
        scale = np.random.uniform(*self.scale_range)
        th = int(round(h * scale))
        tw = int(round(w * scale))
        return h, w, th, tw

class RandomCrop(VisionTransform):
    r"""Crop the input data randomly. Before applying the crop transform,
    the image is padded first; if the target size is still bigger than the
    padded image, the image is further padded up to the target size.

    :param output_size: Target size of the output image, with (height, width) shape.
    :param padding_size: The same with `size` in ``Pad``.
    :param padding_value: The same with `value` in ``Pad``.
    :param padding_maskvalue: The same with `mask_value` in ``Pad``.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(
        self,
        output_size,
        padding_size=0,
        padding_value=[0, 0, 0],
        padding_maskvalue=0,
        *,
        order=None
    ):
        super().__init__(order)
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size
        self.pad = Pad(padding_size, padding_value, order=self.order)
        self.padding_value = padding_value
        self.padding_maskvalue = padding_maskvalue

    def apply(self, input):
        input = self.pad.apply(input)
        self._h, self._w, _ = self._get_image(input).shape
        self._th, self._tw = self.output_size
        self._x = np.random.randint(0, max(0, self._w - self._tw) + 1)
        self._y = np.random.randint(0, max(0, self._h - self._th) + 1)
        return super().apply(input)

    def _apply_image(self, image):
        if self._th > self._h:
            image = F.pad(image, (self._th - self._h, 0), self.padding_value)
        if self._tw > self._w:
            image = F.pad(image, (0, self._tw - self._w), self.padding_value)
        return image[self._y : self._y + self._th, self._x : self._x + self._tw]

    def _apply_coords(self, coords):
        coords[:, 0] -= self._x
        coords[:, 1] -= self._y
        return coords

    def _apply_mask(self, mask):
        if self._th > self._h:
            mask = F.pad(mask, (self._th - self._h, 0), self.padding_maskvalue)
        if self._tw > self._w:
            mask = F.pad(mask, (0, self._tw - self._w), self.padding_maskvalue)
        return mask[self._y : self._y + self._th, self._x : self._x + self._tw]
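
# Usage sketch (illustrative): random 100x100 crop from a padded image.
#
#   import numpy as np
#   img = np.zeros((128, 128, 3), dtype=np.uint8)
#   crop = RandomCrop(100, padding_size=4).apply(img)
#   assert crop.shape == (100, 100, 3)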

class RandomResizedCrop(VisionTransform):
    r"""Crop the input data to a random size and aspect ratio.
    A crop of random size (default: 0.08 to 1.0 of the original area) and random
    aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is made.
    After applying the crop transform, the input data will be resized to the given size.

    :param output_size: Target size of the output image, with (height, width) shape.
    :param scale_range: Range of the cropped area relative to the original. Default: (0.08, 1.0)
    :param ratio_range: Range of the aspect ratio of the crop relative to the original. Default: (3/4, 4/3)
    :param interpolation: Interpolation method; the same with ``Resize``.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(
        self,
        output_size,
        scale_range=(0.08, 1.0),
        ratio_range=(3.0 / 4, 4.0 / 3),
        interpolation=cv2.INTER_LINEAR,
        *,
        order=None
    ):
        super().__init__(order)
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size
        assert (
            scale_range[0] <= scale_range[1]
        ), "scale_range should be of kind (min, max)"
        assert (
            ratio_range[0] <= ratio_range[1]
        ), "ratio_range should be of kind (min, max)"
        self.scale_range = scale_range
        self.ratio_range = ratio_range
        self.interpolation = interpolation

    def apply(self, input: Tuple):
        self._coord_info = self._get_coord(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        x, y, w, h = self._coord_info
        cropped_img = image[y : y + h, x : x + w]
        return F.resize(cropped_img, self.output_size, self.interpolation)

    def _apply_coords(self, coords):
        x, y, w, h = self._coord_info
        coords[:, 0] = (coords[:, 0] - x) * self.output_size[1] / w
        coords[:, 1] = (coords[:, 1] - y) * self.output_size[0] / h
        return coords

    def _apply_mask(self, mask):
        x, y, w, h = self._coord_info
        cropped_mask = mask[y : y + h, x : x + w]
        return F.resize(cropped_mask, self.output_size, cv2.INTER_NEAREST)

    def _get_coord(self, image, attempts=10):
        height, width, _ = image.shape
        area = height * width
        for _ in range(attempts):
            target_area = np.random.uniform(*self.scale_range) * area
            log_ratio = tuple(math.log(x) for x in self.ratio_range)
            aspect_ratio = math.exp(np.random.uniform(*log_ratio))
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < w <= width and 0 < h <= height:
                x = np.random.randint(0, width - w + 1)
                y = np.random.randint(0, height - h + 1)
                return x, y, w, h
        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(self.ratio_range):
            w = width
            h = int(round(w / min(self.ratio_range)))
        elif in_ratio > max(self.ratio_range):
            h = height
            w = int(round(h * max(self.ratio_range)))
        else:  # whole image
            w = width
            h = height
        x = (width - w) // 2
        y = (height - h) // 2
        return x, y, w, h
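
# Usage sketch (illustrative): the standard ImageNet-style training crop.
#
#   import numpy as np
#   img = np.zeros((256, 384, 3), dtype=np.uint8)
#   out = RandomResizedCrop(224).apply(img)
#   assert out.shape == (224, 224, 3)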

class CenterCrop(VisionTransform):
    r"""Crop the input data at the center.

    :param output_size: Target size of the output image, with (height, width) shape.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, output_size, *, order=None):
        super().__init__(order)
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            self.output_size = output_size

    def apply(self, input: Tuple):
        self._coord_info = self._get_coord(self._get_image(input))
        return super().apply(input)

    def _apply_image(self, image):
        x, y = self._coord_info
        th, tw = self.output_size
        return image[y : y + th, x : x + tw]

    def _apply_coords(self, coords):
        x, y = self._coord_info
        coords[:, 0] -= x
        coords[:, 1] -= y
        return coords

    def _apply_mask(self, mask):
        x, y = self._coord_info
        th, tw = self.output_size
        return mask[y : y + th, x : x + tw]

    def _get_coord(self, image):
        th, tw = self.output_size
        h, w, _ = image.shape
        assert th <= h and tw <= w, "output size is bigger than image size"
        x = int(round((w - tw) / 2.0))
        y = int(round((h - th) / 2.0))
        return x, y
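
# Usage sketch (illustrative): deterministic center crop for evaluation.
#
#   import numpy as np
#   img = np.zeros((256, 256, 3), dtype=np.uint8)
#   out = CenterCrop(224).apply(img)
#   assert out.shape == (224, 224, 3)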

class RandomHorizontalFlip(VisionTransform):
    r"""Horizontally flip the input data randomly with a given probability.

    :param prob: Probability of the input data being flipped. Default: 0.5
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, prob: float = 0.5, *, order=None):
        super().__init__(order)
        self.prob = prob

    def apply(self, input: Tuple):
        self._flipped = np.random.random() < self.prob
        self._w = self._get_image(input).shape[1]
        return super().apply(input)

    def _apply_image(self, image):
        if self._flipped:
            return F.flip(image, flipCode=1)
        return image

    def _apply_coords(self, coords):
        if self._flipped:
            coords[:, 0] = self._w - coords[:, 0]
        return coords

    def _apply_mask(self, mask):
        if self._flipped:
            return F.flip(mask, flipCode=1)
        return mask
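
# Usage sketch (illustrative): flipping an image together with its boxes.
#
#   import numpy as np
#   img = np.zeros((100, 200, 3), dtype=np.uint8)
#   boxes = np.array([[10, 20, 60, 80]], dtype=np.float32)   # xyxy format
#   flip = RandomHorizontalFlip(prob=1.0, order=("image", "boxes"))
#   _, out_boxes = flip.apply((img, boxes))
#   # x-coordinates are mirrored around the image width: [140, 20, 190, 80]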

class RandomVerticalFlip(VisionTransform):
    r"""Vertically flip the input data randomly with a given probability.

    :param prob: Probability of the input data being flipped. Default: 0.5
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, prob: float = 0.5, *, order=None):
        super().__init__(order)
        self.prob = prob

    def apply(self, input: Tuple):
        self._flipped = np.random.random() < self.prob
        self._h = self._get_image(input).shape[0]
        return super().apply(input)

    def _apply_image(self, image):
        if self._flipped:
            return F.flip(image, flipCode=0)
        return image

    def _apply_coords(self, coords):
        if self._flipped:
            coords[:, 1] = self._h - coords[:, 1]
        return coords

    def _apply_mask(self, mask):
        if self._flipped:
            return F.flip(mask, flipCode=0)
        return mask

class Normalize(VisionTransform):
    r"""Normalize the input data with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels,
    this transform will normalize each channel of the input data:
    ``output[channel] = (input[channel] - mean[channel]) / std[channel]``

    :param mean: Sequence of means for each channel.
    :param std: Sequence of standard deviations for each channel.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, mean=0.0, std=1.0, *, order=None):
        super().__init__(order)
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def _apply_image(self, image):
        return (image - self.mean) / self.std

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask
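
# Usage sketch (illustrative): per-channel normalization with dataset statistics.
# The mean/std values below are hypothetical placeholders.
#
#   import numpy as np
#   img = np.zeros((64, 64, 3), dtype=np.float32)
#   norm = Normalize(mean=[103.5, 116.3, 123.7], std=[57.4, 57.1, 58.4])
#   out = norm.apply(img)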

class GaussianNoise(VisionTransform):
    r"""Add random gaussian noise to the input data.
    Gaussian noise is generated with given mean and std.

    :param mean: Gaussian mean used to generate noise.
    :param std: Gaussian standard deviation used to generate noise.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, mean=0.0, std=1.0, *, order=None):
        super().__init__(order)
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def _apply_image(self, image):
        dtype = image.dtype
        noise = np.random.normal(self.mean, self.std, image.shape) * 255
        image = image + noise.astype(np.float32)
        return np.clip(image, 0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask

class BrightnessTransform(VisionTransform):
    r"""Adjust the brightness of the input data.

    :param value: How much to adjust the brightness. Can be any
        non-negative number. 0 gives the original image.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0:
            raise ValueError("brightness value should be non-negative")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        image = image * alpha
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask

class ContrastTransform(VisionTransform):
    r"""Adjust the contrast of the input data.

    :param value: How much to adjust the contrast. Can be any
        non-negative number. 0 gives the original image.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0:
            raise ValueError("contrast value should be non-negative")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        image = image * alpha + F.to_gray(image).mean() * (1 - alpha)
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask

class SaturationTransform(VisionTransform):
    r"""Adjust the saturation of the input data.

    :param value: How much to adjust the saturation. Can be any
        non-negative number. 0 gives the original image.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0:
            raise ValueError("saturation value should be non-negative")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        alpha = np.random.uniform(max(0, 1 - self.value), 1 + self.value)
        image = image * alpha + F.to_gray(image) * (1 - alpha)
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask

class HueTransform(VisionTransform):
    r"""Adjust the hue of the input data.

    :param value: How much to adjust the hue. Can be any number
        between 0 and 0.5; 0 gives the original image.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, value, *, order=None):
        super().__init__(order)
        if value < 0 or value > 0.5:
            raise ValueError("hue value should be in [0.0, 0.5]")
        self.value = value

    def _apply_image(self, image):
        if self.value == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.uint8)
        hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV_FULL)
        h, s, v = cv2.split(hsv_image)
        alpha = np.random.uniform(-self.value, self.value)
        h = h.astype(np.uint8)
        # uint8 addition takes care of wrap-around across the hue boundary
        with np.errstate(over="ignore"):
            h += np.uint8(alpha * 255)
        hsv_image = cv2.merge([h, s, v])
        return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR_FULL).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask

class ColorJitter(VisionTransform):
    r"""Randomly change the brightness, contrast, saturation and hue of an image.

    :param brightness: How much to jitter brightness.
        Chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
        or the given [min, max]. Should be non-negative numbers.
    :param contrast: How much to jitter contrast.
        Chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
        or the given [min, max]. Should be non-negative numbers.
    :param saturation: How much to jitter saturation.
        Chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
        or the given [min, max]. Should be non-negative numbers.
    :param hue: How much to jitter hue.
        Chosen uniformly from [-hue, hue] or the given [min, max].
        Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, *, order=None):
        super().__init__(order)
        transforms = []
        if brightness != 0:
            transforms.append(BrightnessTransform(brightness))
        if contrast != 0:
            transforms.append(ContrastTransform(contrast))
        if saturation != 0:
            transforms.append(SaturationTransform(saturation))
        if hue != 0:
            transforms.append(HueTransform(hue))
        self.transforms = Compose(
            transforms,
            shuffle_indices=[tuple(range(1, len(transforms) + 1))],
            order=order,
        )

    def apply(self, input):
        return self.transforms.apply(input)
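
# Usage sketch (illustrative): jitter all four properties at once; the
# sub-transforms are applied in a randomly shuffled order each call.
#
#   import numpy as np
#   img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
#   jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
#   out = jitter.apply(img)   # same shape and dtype, randomly perturbed colors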

class Lighting(VisionTransform):
    r"""Apply AlexNet-style PCA lighting noise to the input image.

    A 3-vector is sampled from a normal distribution with the given scale,
    weighted by fixed PCA eigenvalues, projected through the corresponding
    eigenvectors (stored in BGR channel order), and added to every pixel.

    :param scale: Standard deviation of the normal distribution used to
        sample the noise. 0 gives the original image.
    :param order: The same with ``VisionTransform``.
    """

    def __init__(self, scale, *, order=None):
        super().__init__(order)
        if scale < 0:
            raise ValueError("lighting scale should be non-negative")
        self.scale = scale
        self.eigvec = np.array(
            [
                [-0.5836, -0.6948, 0.4203],
                [-0.5808, -0.0045, -0.8140],
                [-0.5675, 0.7192, 0.4009],
            ]
        )  # reverse the first dimension for BGR
        self.eigval = np.array([0.2175, 0.0188, 0.0045])

    def _apply_image(self, image):
        if self.scale == 0:
            return image
        dtype = image.dtype
        image = image.astype(np.float32)
        alpha = np.random.normal(scale=self.scale, size=3)
        image = image + self.eigvec.dot(alpha * self.eigval)
        return image.clip(0, 255).astype(dtype)

    def _apply_coords(self, coords):
        return coords

    def _apply_mask(self, mask):
        return mask
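
# Usage sketch (illustrative): lighting augmentation on a float image.
#
#   import numpy as np
#   img = np.random.randint(0, 256, (64, 64, 3)).astype(np.float32)
#   out = Lighting(scale=0.1).apply(img)   # slightly shifted colors, same shape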

The MegEngine installation package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU/GPU build. To run GPU programs, make sure the machine has a GPU device and the driver installed. If you want to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
