
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn dataset"""
from __future__ import division

import os
import numpy as np
from numpy import random

import mmcv
import mindspore.dataset as de
import mindspore.dataset.transforms.vision.c_transforms as C
from mindspore.mindrecord import FileWriter

from src.config import config


def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
    """Calculate the ious between each bbox of bboxes1 and bboxes2.

    Args:
        bboxes1(ndarray): shape (n, 4)
        bboxes2(ndarray): shape (k, 4)
        mode(str): iou (intersection over union) or iof (intersection
            over foreground)

    Returns:
        ious(ndarray): shape (n, k)
    """
    assert mode in ['iou', 'iof']

    bboxes1 = bboxes1.astype(np.float32)
    bboxes2 = bboxes2.astype(np.float32)
    rows = bboxes1.shape[0]
    cols = bboxes2.shape[0]
    ious = np.zeros((rows, cols), dtype=np.float32)
    if rows * cols == 0:
        return ious
    exchange = False
    if bboxes1.shape[0] > bboxes2.shape[0]:
        bboxes1, bboxes2 = bboxes2, bboxes1
        ious = np.zeros((cols, rows), dtype=np.float32)
        exchange = True
    area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (bboxes1[:, 3] - bboxes1[:, 1] + 1)
    area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (bboxes2[:, 3] - bboxes2[:, 1] + 1)
    for i in range(bboxes1.shape[0]):
        x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
        y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
        x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
        y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
        overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum(
            y_end - y_start + 1, 0)
        if mode == 'iou':
            union = area1[i] + area2 - overlap
        else:
            union = area1[i] if not exchange else area2
        ious[i, :] = overlap / union
    if exchange:
        ious = ious.T
    return ious
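
# Illustrative example (added for clarity, not used by the pipeline): with the
# inclusive "+1" box convention above, one ground-truth box and two anchors give
#     gt = np.array([[0, 0, 9, 9]], dtype=np.float32)
#     anchors = np.array([[0, 0, 9, 9], [5, 5, 14, 14]], dtype=np.float32)
#     bbox_overlaps(gt, anchors)  # ~[[1.0, 0.143]]: a perfect match and 25/175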


class PhotoMetricDistortion:
    """Photo Metric Distortion"""

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, img, boxes, labels):
        # random brightness
        img = img.astype('float32')
        if random.randint(2):
            delta = random.uniform(-self.brightness_delta,
                                   self.brightness_delta)
            img += delta

        # mode == 0 --> do random contrast first
        # mode == 1 --> do random contrast last
        mode = random.randint(2)
        if mode == 1:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha

        # convert color from BGR to HSV
        img = mmcv.bgr2hsv(img)

        # random saturation
        if random.randint(2):
            img[..., 1] *= random.uniform(self.saturation_lower,
                                          self.saturation_upper)

        # random hue
        if random.randint(2):
            img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
            img[..., 0][img[..., 0] > 360] -= 360
            img[..., 0][img[..., 0] < 0] += 360

        # convert color from HSV to BGR
        img = mmcv.hsv2bgr(img)

        # random contrast
        if mode == 0:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha

        # randomly swap channels
        if random.randint(2):
            img = img[..., random.permutation(3)]

        return img, boxes, labels
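
# Note (added for clarity): every distortion above is gated by random.randint(2),
# so each step is applied with probability 1/2 per call; boxes and labels pass
# through unchanged.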


class Expand:
    """expand image"""

    def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
        if to_rgb:
            self.mean = mean[::-1]
        else:
            self.mean = mean
        self.min_ratio, self.max_ratio = ratio_range

    def __call__(self, img, boxes, labels):
        if random.randint(2):
            return img, boxes, labels

        h, w, c = img.shape
        ratio = random.uniform(self.min_ratio, self.max_ratio)
        expand_img = np.full((int(h * ratio), int(w * ratio), c),
                             self.mean).astype(img.dtype)
        left = int(random.uniform(0, w * ratio - w))
        top = int(random.uniform(0, h * ratio - h))
        expand_img[top:top + h, left:left + w] = img
        img = expand_img
        boxes += np.tile((left, top), 2)
        return img, boxes, labels
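
# Illustrative note (added): when Expand pastes the image at offset (left, top)
# inside the enlarged canvas, each box [x1, y1, x2, y2] becomes
# [x1 + left, y1 + top, x2 + left, y2 + top]; labels are unchanged.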


def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """rescale operation for image"""
    img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
    if img_data.shape[0] > config.img_height:
        img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_width), return_scale=True)
        scale_factor = scale_factor * scale_factor2

    img_shape = np.append(img_shape, scale_factor)
    img_shape = np.asarray(img_shape, dtype=np.float32)

    gt_bboxes = gt_bboxes * scale_factor
    gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
    gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)

    return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """resize operation for image"""
    img_data = img
    img_data, w_scale, h_scale = mmcv.imresize(
        img_data, (config.img_width, config.img_height), return_scale=True)
    scale_factor = np.array(
        [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
    img_shape = (config.img_height, config.img_width, 1.0)
    img_shape = np.asarray(img_shape, dtype=np.float32)

    gt_bboxes = gt_bboxes * scale_factor
    gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
    gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)

    return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num):
    """resize operation for image of eval"""
    img_data = img
    img_data, w_scale, h_scale = mmcv.imresize(
        img_data, (config.img_width, config.img_height), return_scale=True)
    scale_factor = np.array(
        [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
    img_shape = np.append(img_shape, (h_scale, w_scale))
    img_shape = np.asarray(img_shape, dtype=np.float32)

    gt_bboxes = gt_bboxes * scale_factor
    gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
    gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)

    return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """impad operation for image"""
    img_data = mmcv.impad(img, (config.img_height, config.img_width))
    img_data = img_data.astype(np.float32)
    return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """imnormalize operation for image"""
    img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True)
    img_data = img_data.astype(np.float32)
    return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """flip operation for image"""
    img_data = img
    img_data = mmcv.imflip(img_data)
    flipped = gt_bboxes.copy()
    _, w, _ = img_data.shape

    # With the inclusive pixel-coordinate convention used here, a horizontal flip
    # maps x to (w - x - 1), so the left and right box edges swap roles.
    flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
    flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1

    return (img_data, img_shape, flipped, gt_label, gt_num)


def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """transpose operation for image"""
    img_data = img.transpose(2, 0, 1).copy()
    img_data = img_data.astype(np.float16)
    img_shape = img_shape.astype(np.float16)
    gt_bboxes = gt_bboxes.astype(np.float16)
    gt_label = gt_label.astype(np.int32)
    gt_num = gt_num.astype(np.bool)

    return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """photo crop operation for image"""
    random_photo = PhotoMetricDistortion()
    img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label)

    return (img_data, img_shape, gt_bboxes, gt_label, gt_num)


def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num):
    """expand operation for image"""
    expand = Expand()
    img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label)

    return (img, img_shape, gt_bboxes, gt_label, gt_num)


def preprocess_fn(image, box, is_training):
    """Preprocess function for dataset."""
    def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert):
        image_shape = image_shape[:2]
        input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert

        if config.keep_ratio:
            input_data = rescale_column(*input_data)
        else:
            input_data = resize_column_test(*input_data)
        input_data = imnormalize_column(*input_data)

        output_data = transpose_column(*input_data)
        return output_data

    def _data_aug(image, box, is_training):
        """Data augmentation function."""
        image_bgr = image.copy()
        image_bgr[:, :, 0] = image[:, :, 2]
        image_bgr[:, :, 1] = image[:, :, 1]
        image_bgr[:, :, 2] = image[:, :, 0]
        image_shape = image_bgr.shape[:2]
        gt_box = box[:, :4]
        gt_label = box[:, 4]
        gt_iscrowd = box[:, 5]

        pad_max_number = 128
        gt_box_new = np.pad(gt_box, ((0, pad_max_number - box.shape[0]), (0, 0)), mode="constant", constant_values=0)
        gt_label_new = np.pad(gt_label, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=-1)
        gt_iscrowd_new = np.pad(gt_iscrowd, ((0, pad_max_number - box.shape[0])), mode="constant", constant_values=1)
        gt_iscrowd_new_revert = (~(gt_iscrowd_new.astype(np.bool))).astype(np.int32)

        if not is_training:
            return _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert)

        flip = (np.random.rand() < config.flip_ratio)
        photo = (np.random.rand() < config.photo_ratio)
        expand = (np.random.rand() < config.expand_ratio)

        input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert

        if expand:
            input_data = expand_column(*input_data)
        if config.keep_ratio:
            input_data = rescale_column(*input_data)
        else:
            input_data = resize_column(*input_data)
        if photo:
            input_data = photo_crop_column(*input_data)
        input_data = imnormalize_column(*input_data)
        if flip:
            input_data = flip_column(*input_data)

        output_data = transpose_column(*input_data)
        return output_data

    return _data_aug(image, box, is_training)
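
# Note on the output layout (derived from the transforms above): each sample is
# returned as (image, image_shape, boxes, labels, valid_mask), where the image is
# a float16 CHW array, boxes and labels are padded to 128 entries (labels padded
# with -1), and valid_mask is True only for real, non-crowd boxes.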


def create_coco_label(is_training):
    """Get image path and annotation from COCO."""
    from pycocotools.coco import COCO

    coco_root = config.coco_root
    data_type = config.val_data_type
    if is_training:
        data_type = config.train_data_type

    # Classes needed for training or testing.
    train_cls = config.coco_classes
    train_cls_dict = {}
    for i, cls in enumerate(train_cls):
        train_cls_dict[cls] = i

    anno_json = os.path.join(coco_root, config.instance_set.format(data_type))

    coco = COCO(anno_json)
    classs_dict = {}
    cat_ids = coco.loadCats(coco.getCatIds())
    for cat in cat_ids:
        classs_dict[cat["id"]] = cat["name"]

    image_ids = coco.getImgIds()
    image_files = []
    image_anno_dict = {}

    for img_id in image_ids:
        image_info = coco.loadImgs(img_id)
        file_name = image_info[0]["file_name"]
        anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = coco.loadAnns(anno_ids)
        image_path = os.path.join(coco_root, data_type, file_name)
        annos = []
        for label in anno:
            bbox = label["bbox"]
            class_name = classs_dict[label["category_id"]]
            if class_name in train_cls:
                x1, x2 = bbox[0], bbox[0] + bbox[2]
                y1, y2 = bbox[1], bbox[1] + bbox[3]
                annos.append([x1, y1, x2, y2] + [train_cls_dict[class_name]] + [int(label["iscrowd"])])

        image_files.append(image_path)
        if annos:
            image_anno_dict[image_path] = np.array(annos)
        else:
            image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1])

    return image_files, image_anno_dict
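
# Each annotation row produced above is [x1, y1, x2, y2, class_id, iscrowd];
# images without usable labels fall back to a single row [0, 0, 0, 0, 0, 1].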


def anno_parser(annos_str):
    """Parse annotation from string to list."""
    annos = []
    for anno_str in annos_str:
        anno = list(map(int, anno_str.strip().split(',')))
        annos.append(anno)
    return annos


def filter_valid_data(image_dir, anno_path):
    """Filter valid image files, keeping only those listed in anno_path and present in image_dir."""
    image_files = []
    image_anno_dict = {}
    if not os.path.isdir(image_dir):
        raise RuntimeError("Path given is not valid.")
    if not os.path.isfile(anno_path):
        raise RuntimeError("Annotation file is not valid.")

    with open(anno_path, "rb") as f:
        lines = f.readlines()
    for line in lines:
        line_str = line.decode("utf-8").strip()
        line_split = str(line_str).split(' ')
        file_name = line_split[0]
        image_path = os.path.join(image_dir, file_name)
        if os.path.isfile(image_path):
            image_anno_dict[image_path] = anno_parser(line_split[1:])
            image_files.append(image_path)
    return image_files, image_anno_dict


def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="fasterrcnn.mindrecord", file_num=8):
    """Create MindRecord file."""
    mindrecord_dir = config.mindrecord_dir
    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    if dataset == "coco":
        image_files, image_anno_dict = create_coco_label(is_training)
    else:
        image_files, image_anno_dict = filter_valid_data(config.IMAGE_DIR, config.ANNO_PATH)

    fasterrcnn_json = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int32", "shape": [-1, 6]},
    }
    writer.add_schema(fasterrcnn_json, "fasterrcnn_json")

    for image_name in image_files:
        with open(image_name, 'rb') as f:
            img = f.read()
        annos = np.array(image_anno_dict[image_name], dtype=np.int32)
        row = {"image": img, "annotation": annos}
        writer.write_raw_data([row])
    writer.commit()


def create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=12, device_num=1, rank_id=0,
                              is_training=True, num_parallel_workers=8):
    """Create FasterRcnn dataset with MindDataset."""
    ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,
                        num_parallel_workers=num_parallel_workers, shuffle=is_training)
    decode = C.Decode()
    ds = ds.map(input_columns=["image"], operations=decode)
    compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))

    if is_training:
        ds = ds.map(input_columns=["image", "annotation"],
                    output_columns=["image", "image_shape", "box", "label", "valid_num"],
                    columns_order=["image", "image_shape", "box", "label", "valid_num"],
                    operations=compose_map_func, python_multiprocessing=True,
                    num_parallel_workers=num_parallel_workers)
        ds = ds.batch(batch_size, drop_remainder=True)
        ds = ds.repeat(repeat_num)
    else:
        ds = ds.map(input_columns=["image", "annotation"],
                    output_columns=["image", "image_shape", "box", "label", "valid_num"],
                    columns_order=["image", "image_shape", "box", "label", "valid_num"],
                    operations=compose_map_func,
                    num_parallel_workers=num_parallel_workers)
        ds = ds.batch(batch_size, drop_remainder=True)
        ds = ds.repeat(repeat_num)
    return ds
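

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration, not part of the original
    # training scripts). It assumes config.mindrecord_dir already exists and that
    # the default prefix/file_num above produce shards named
    # "fasterrcnn.mindrecord0" ... "fasterrcnn.mindrecord7"; adjust the shard name
    # if you change those arguments.
    mindrecord_file = os.path.join(config.mindrecord_dir, "fasterrcnn.mindrecord0")
    if not os.path.exists(mindrecord_file):
        data_to_mindrecord_byte_image(dataset="coco", is_training=True)
    dataset = create_fasterrcnn_dataset(mindrecord_file, batch_size=2, repeat_num=1, is_training=True)
    print("dataset size:", dataset.get_dataset_size())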