You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

dataset.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """MaskRcnn dataset"""
  16. from __future__ import division
  17. import os
  18. import numpy as np
  19. from numpy import random
  20. import mmcv
  21. import mindspore.dataset as de
  22. import mindspore.dataset.transforms.vision.c_transforms as C
  23. from mindspore.mindrecord import FileWriter
  24. from src.config import config
  25. import cv2
  26. def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
  27. """Calculate the ious between each bbox of bboxes1 and bboxes2.
  28. Args:
  29. bboxes1(ndarray): shape (n, 4)
  30. bboxes2(ndarray): shape (k, 4)
  31. mode(str): iou (intersection over union) or iof (intersection
  32. over foreground)
  33. Returns:
  34. ious(ndarray): shape (n, k)
  35. """
  36. assert mode in ['iou', 'iof']
  37. bboxes1 = bboxes1.astype(np.float32)
  38. bboxes2 = bboxes2.astype(np.float32)
  39. rows = bboxes1.shape[0]
  40. cols = bboxes2.shape[0]
  41. ious = np.zeros((rows, cols), dtype=np.float32)
  42. if rows * cols == 0:
  43. return ious
  44. exchange = False
  45. if bboxes1.shape[0] > bboxes2.shape[0]:
  46. bboxes1, bboxes2 = bboxes2, bboxes1
  47. ious = np.zeros((cols, rows), dtype=np.float32)
  48. exchange = True
  49. area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (bboxes1[:, 3] - bboxes1[:, 1] + 1)
  50. area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (bboxes2[:, 3] - bboxes2[:, 1] + 1)
  51. for i in range(bboxes1.shape[0]):
  52. x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
  53. y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
  54. x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
  55. y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
  56. overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum(
  57. y_end - y_start + 1, 0)
  58. if mode == 'iou':
  59. union = area1[i] + area2 - overlap
  60. else:
  61. union = area1[i] if not exchange else area2
  62. ious[i, :] = overlap / union
  63. if exchange:
  64. ious = ious.T
  65. return ious
class PhotoMetricDistortion:
    """Randomly distort image color (brightness, contrast, saturation, hue).

    Each sub-step is applied with probability 0.5 (via ``random.randint(2)``).
    The contrast step runs either before or after the HSV-space steps,
    chosen at random per call.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        # brightness_delta: max absolute additive shift in pixel value
        self.brightness_delta = brightness_delta
        # multiplicative ranges for contrast / saturation
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        # hue_delta: max absolute shift applied to the H channel
        self.hue_delta = hue_delta

    def __call__(self, img, boxes, labels):
        """Apply the distortion; boxes and labels pass through unchanged."""
        # random brightness (additive shift)
        img = img.astype('float32')
        if random.randint(2):
            delta = random.uniform(-self.brightness_delta,
                                   self.brightness_delta)
            img += delta

        # mode == 0 --> do random contrast first
        # mode == 1 --> do random contrast last
        mode = random.randint(2)
        if mode == 1:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha

        # convert color from BGR to HSV
        img = mmcv.bgr2hsv(img)

        # random saturation (multiplicative, S channel)
        if random.randint(2):
            img[..., 1] *= random.uniform(self.saturation_lower,
                                          self.saturation_upper)

        # random hue: shift H channel and wrap into [0, 360)
        if random.randint(2):
            img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
            img[..., 0][img[..., 0] > 360] -= 360
            img[..., 0][img[..., 0] < 0] += 360

        # convert color from HSV back to BGR
        img = mmcv.hsv2bgr(img)

        # random contrast (if not already applied above)
        if mode == 0:
            if random.randint(2):
                alpha = random.uniform(self.contrast_lower,
                                       self.contrast_upper)
                img *= alpha

        # randomly swap color channels
        if random.randint(2):
            img = img[..., random.permutation(3)]

        return img, boxes, labels
class Expand:
    """Randomly place the image on a larger mean-filled canvas (50% chance).

    Boxes are shifted by the placement offset and masks are zero-padded to
    the expanded canvas size.
    """

    def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
        # Reverse the fill color when to_rgb is set — presumably so the fill
        # matches the BGR channel order used by this pipeline (TODO confirm).
        if to_rgb:
            self.mean = mean[::-1]
        else:
            self.mean = mean
        self.min_ratio, self.max_ratio = ratio_range

    def __call__(self, img, boxes, labels, mask):
        # 50% chance to leave the sample untouched
        if random.randint(2):
            return img, boxes, labels, mask
        h, w, c = img.shape
        ratio = random.uniform(self.min_ratio, self.max_ratio)
        expand_img = np.full((int(h * ratio), int(w * ratio), c),
                             self.mean).astype(img.dtype)
        # random top-left placement of the original image on the canvas
        left = int(random.uniform(0, w * ratio - w))
        top = int(random.uniform(0, h * ratio - h))
        expand_img[top:top + h, left:left + w] = img
        img = expand_img
        # shift all boxes by (left, top, left, top)
        boxes += np.tile((left, top), 2)
        # pad the instance masks to the expanded canvas at the same offset
        mask_count, mask_h, mask_w = mask.shape
        expand_mask = np.zeros((mask_count, int(mask_h * ratio), int(mask_w * ratio))).astype(mask.dtype)
        expand_mask[:, top:top + h, left:left + w] = mask
        mask = expand_mask
        return img, boxes, labels, mask
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
    """rescale operation for image

    Rescales the image (keeping aspect ratio) to fit within
    (config.img_width, config.img_height), scales boxes and masks by the
    same factor, then zero-pads image and masks to the fixed input size.
    """
    img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
    # Second rescale pass in case the height still exceeds the target.
    if img_data.shape[0] > config.img_height:
        img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_height), return_scale=True)
        scale_factor = scale_factor*scale_factor2

    gt_bboxes = gt_bboxes * scale_factor
    # NOTE(review): boxes are clipped against img_shape — the ORIGINAL image
    # size passed in, not the rescaled size. Confirm this is intentional.
    gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
    gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)

    # rescale each instance mask with nearest-neighbor to keep it binary
    gt_mask_data = np.array([
        mmcv.imrescale(mask, scale_factor, interpolation='nearest')
        for mask in gt_mask
    ])

    # zero-pad image to the fixed (img_height, img_width) network input size
    pad_h = config.img_height - img_data.shape[0]
    pad_w = config.img_width - img_data.shape[1]
    assert ((pad_h >= 0) and (pad_w >= 0))
    pad_img_data = np.zeros((config.img_height, config.img_width, 3)).astype(img_data.dtype)
    pad_img_data[0:img_data.shape[0], 0:img_data.shape[1], :] = img_data

    # pad masks the same way (top-left anchored)
    mask_count, mask_h, mask_w = gt_mask_data.shape
    pad_mask = np.zeros((mask_count, config.img_height, config.img_width)).astype(gt_mask_data.dtype)
    pad_mask[:, 0:mask_h, 0:mask_w] = gt_mask_data

    # report the padded shape with a unit scale factor
    img_shape = (config.img_height, config.img_width, 1.0)
    img_shape = np.asarray(img_shape, dtype=np.float32)

    return (pad_img_data, img_shape, gt_bboxes, gt_label, gt_num, pad_mask)
def rescale_column_test(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
    """rescale operation for image of eval

    Rescales and pads the image only; boxes and masks are returned
    unchanged, and the scale factor is appended to img_shape.
    """
    img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
    # Second rescale pass in case the height still exceeds the target.
    if img_data.shape[0] > config.img_height:
        img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_height), return_scale=True)
        scale_factor = scale_factor*scale_factor2

    # zero-pad image to the fixed (img_height, img_width) network input size
    pad_h = config.img_height - img_data.shape[0]
    pad_w = config.img_width - img_data.shape[1]
    assert ((pad_h >= 0) and (pad_w >= 0))
    pad_img_data = np.zeros((config.img_height, config.img_width, 3)).astype(img_data.dtype)
    pad_img_data[0:img_data.shape[0], 0:img_data.shape[1], :] = img_data

    # append (scale, scale) to the original shape for downstream use
    img_shape = np.append(img_shape, (scale_factor, scale_factor))
    img_shape = np.asarray(img_shape, dtype=np.float32)

    return (pad_img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask)
  178. def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
  179. """resize operation for image"""
  180. img_data = img
  181. img_data, w_scale, h_scale = mmcv.imresize(
  182. img_data, (config.img_width, config.img_height), return_scale=True)
  183. scale_factor = np.array(
  184. [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
  185. img_shape = (config.img_height, config.img_width, 1.0)
  186. img_shape = np.asarray(img_shape, dtype=np.float32)
  187. gt_bboxes = gt_bboxes * scale_factor
  188. gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) # x1, x2 [0, W-1]
  189. gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) # y1, y2 [0, H-1]
  190. gt_mask_data = np.array([
  191. mmcv.imresize(mask, (config.img_width, config.img_height), interpolation='nearest')
  192. for mask in gt_mask
  193. ])
  194. return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask_data)
  195. def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
  196. """resize operation for image of eval"""
  197. img_data = img
  198. img_data, w_scale, h_scale = mmcv.imresize(
  199. img_data, (config.img_width, config.img_height), return_scale=True)
  200. img_shape = np.append(img_shape, (h_scale, w_scale))
  201. img_shape = np.asarray(img_shape, dtype=np.float32)
  202. return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask)
  203. def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
  204. """impad operation for image"""
  205. img_data = mmcv.impad(img, (config.img_height, config.img_width))
  206. img_data = img_data.astype(np.float32)
  207. return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask)
  208. def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
  209. """imnormalize operation for image"""
  210. img_data = mmcv.imnormalize(img, [123.675, 116.28, 103.53], [58.395, 57.12, 57.375], True)
  211. img_data = img_data.astype(np.float32)
  212. return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask)
  213. def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
  214. """flip operation for image"""
  215. img_data = img
  216. img_data = mmcv.imflip(img_data)
  217. flipped = gt_bboxes.copy()
  218. _, w, _ = img_data.shape
  219. flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1 # x1 = W-x2-1
  220. flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1 # x2 = W-x1-1
  221. gt_mask_data = np.array([mask[:, ::-1] for mask in gt_mask])
  222. return (img_data, img_shape, flipped, gt_label, gt_num, gt_mask_data)
  223. def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
  224. """transpose operation for image"""
  225. img_data = img.transpose(2, 0, 1).copy()
  226. img_data = img_data.astype(np.float16)
  227. img_shape = img_shape.astype(np.float16)
  228. gt_bboxes = gt_bboxes.astype(np.float16)
  229. gt_label = gt_label.astype(np.int32)
  230. gt_num = gt_num.astype(np.bool)
  231. gt_mask_data = gt_mask.astype(np.bool)
  232. return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask_data)
  233. def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
  234. """photo crop operation for image"""
  235. random_photo = PhotoMetricDistortion()
  236. img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label)
  237. return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask)
  238. def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
  239. """expand operation for image"""
  240. expand = Expand()
  241. img, gt_bboxes, gt_label, gt_mask = expand(img, gt_bboxes, gt_label, gt_mask)
  242. return (img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask)
  243. def pad_to_max(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask, instance_count):
  244. pad_max_number = config.max_instance_count
  245. gt_box_new = np.pad(gt_bboxes, ((0, pad_max_number - instance_count), (0, 0)), mode="constant", constant_values=0)
  246. gt_label_new = np.pad(gt_label, ((0, pad_max_number - instance_count)), mode="constant", constant_values=-1)
  247. gt_iscrowd_new = np.pad(gt_num, ((0, pad_max_number - instance_count)), mode="constant", constant_values=1)
  248. gt_iscrowd_new_revert = ~(gt_iscrowd_new.astype(np.bool))
  249. gt_mask_new = np.pad(gt_mask, ((0, pad_max_number - instance_count), (0, 0), (0, 0)), mode="constant",
  250. constant_values=0)
  251. return img, img_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert, gt_mask_new
def preprocess_fn(image, box, mask, mask_shape, is_training):
    """Preprocess function for dataset.

    Args:
        image: HWC image array; channels 0 and 2 are swapped below
            (RGB -> BGR assumed — TODO confirm against the decode op).
        box: (N, 6) array, columns [x1, y1, x2, y2, label, iscrowd].
        mask: flattened instance masks, reshaped using mask_shape.
        mask_shape: (n, h, w) of the instance mask tensor; n must equal N.
        is_training: when False, skip augmentation (eval path).

    Returns:
        (image, image_shape, boxes, labels, valid_num, masks) cast to
        network dtypes by transpose_column.
    """
    def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert,
                    gt_mask_new, instance_count):
        # Eval path: resize/rescale -> normalize -> pad -> transpose.
        image_shape = image_shape[:2]
        input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert, gt_mask_new

        if config.keep_ratio:
            input_data = rescale_column_test(*input_data)
        else:
            input_data = resize_column_test(*input_data)
        input_data = imnormalize_column(*input_data)

        input_data = pad_to_max(*input_data, instance_count)
        output_data = transpose_column(*input_data)
        return output_data

    def _data_aug(image, box, mask, mask_shape, is_training):
        """Data augmentation function."""
        # swap channels 0 and 2 (assumed RGB -> BGR)
        image_bgr = image.copy()
        image_bgr[:, :, 0] = image[:, :, 2]
        image_bgr[:, :, 1] = image[:, :, 1]
        image_bgr[:, :, 2] = image[:, :, 0]
        image_shape = image_bgr.shape[:2]
        instance_count = box.shape[0]
        # split the packed annotation columns
        gt_box = box[:, :4]
        gt_label = box[:, 4]
        gt_iscrowd = box[:, 5]
        # restore the flattened masks to (n, h, w)
        gt_mask = mask.copy()
        n, h, w = mask_shape
        gt_mask = gt_mask.reshape(n, h, w)
        assert n == box.shape[0]

        if not is_training:
            return _infer_data(image_bgr, image_shape, gt_box, gt_label, gt_iscrowd, gt_mask, instance_count)

        # per-sample random toggles for flip / expand augmentation
        flip = (np.random.rand() < config.flip_ratio)
        expand = (np.random.rand() < config.expand_ratio)

        input_data = image_bgr, image_shape, gt_box, gt_label, gt_iscrowd, gt_mask

        # Training path: expand -> resize/rescale -> normalize -> flip ->
        # pad -> transpose (order matters: flip happens after resizing).
        if expand:
            input_data = expand_column(*input_data)
        if config.keep_ratio:
            input_data = rescale_column(*input_data)
        else:
            input_data = resize_column(*input_data)
        input_data = imnormalize_column(*input_data)
        if flip:
            input_data = flip_column(*input_data)

        input_data = pad_to_max(*input_data, instance_count)
        output_data = transpose_column(*input_data)
        return output_data

    return _data_aug(image, box, mask, mask_shape, is_training)
  299. def annToMask(ann, height, width):
  300. """Convert annotation to RLE and then to binary mask."""
  301. from pycocotools import mask as maskHelper
  302. segm = ann['segmentation']
  303. if isinstance(segm, list):
  304. rles = maskHelper.frPyObjects(segm, height, width)
  305. rle = maskHelper.merge(rles)
  306. elif isinstance(segm['counts'], list):
  307. rle = maskHelper.frPyObjects(segm, height, width)
  308. else:
  309. rle = ann['segmentation']
  310. m = maskHelper.decode(rle)
  311. return m
  312. def create_coco_label(is_training):
  313. """Get image path and annotation from COCO."""
  314. from pycocotools.coco import COCO
  315. coco_root = config.coco_root
  316. data_type = config.val_data_type
  317. if is_training:
  318. data_type = config.train_data_type
  319. #Classes need to train or test.
  320. train_cls = config.coco_classes
  321. train_cls_dict = {}
  322. for i, cls in enumerate(train_cls):
  323. train_cls_dict[cls] = i
  324. anno_json = os.path.join(coco_root, config.instance_set.format(data_type))
  325. coco = COCO(anno_json)
  326. classs_dict = {}
  327. cat_ids = coco.loadCats(coco.getCatIds())
  328. for cat in cat_ids:
  329. classs_dict[cat["id"]] = cat["name"]
  330. image_ids = coco.getImgIds()
  331. image_files = []
  332. image_anno_dict = {}
  333. masks = {}
  334. masks_shape = {}
  335. for img_id in image_ids:
  336. image_info = coco.loadImgs(img_id)
  337. file_name = image_info[0]["file_name"]
  338. anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
  339. anno = coco.loadAnns(anno_ids)
  340. image_path = os.path.join(coco_root, data_type, file_name)
  341. annos = []
  342. instance_masks = []
  343. image_height = coco.imgs[img_id]["height"]
  344. image_width = coco.imgs[img_id]["width"]
  345. print("image file name: ", file_name)
  346. if not is_training:
  347. image_files.append(image_path)
  348. image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1])
  349. masks[image_path] = np.zeros([1, 1, 1], dtype=np.bool).tobytes()
  350. masks_shape[image_path] = np.array([1, 1, 1], dtype=np.int32)
  351. else:
  352. for label in anno:
  353. bbox = label["bbox"]
  354. class_name = classs_dict[label["category_id"]]
  355. if class_name in train_cls:
  356. # get coco mask
  357. m = annToMask(label, image_height, image_width)
  358. if m.max() < 1:
  359. print("all black mask!!!!")
  360. continue
  361. # Resize mask for the crowd
  362. if label['iscrowd'] and (m.shape[0] != image_height or m.shape[1] != image_width):
  363. m = np.ones([image_height, image_width], dtype=np.bool)
  364. instance_masks.append(m)
  365. # get coco bbox
  366. x1, x2 = bbox[0], bbox[0] + bbox[2]
  367. y1, y2 = bbox[1], bbox[1] + bbox[3]
  368. annos.append([x1, y1, x2, y2] + [train_cls_dict[class_name]] + [int(label["iscrowd"])])
  369. else:
  370. print("not in classes: ", class_name)
  371. image_files.append(image_path)
  372. if annos:
  373. image_anno_dict[image_path] = np.array(annos)
  374. instance_masks = np.stack(instance_masks, axis=0).astype(np.bool)
  375. masks[image_path] = np.array(instance_masks).tobytes()
  376. masks_shape[image_path] = np.array(instance_masks.shape, dtype=np.int32)
  377. else:
  378. print("no annotations for image ", file_name)
  379. image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1])
  380. masks[image_path] = np.zeros([1, image_height, image_width], dtype=np.bool).tobytes()
  381. masks_shape[image_path] = np.array([1, image_height, image_width], dtype=np.int32)
  382. return image_files, image_anno_dict, masks, masks_shape
  383. def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="maskrcnn.mindrecord", file_num=8):
  384. """Create MindRecord file."""
  385. mindrecord_dir = config.mindrecord_dir
  386. mindrecord_path = os.path.join(mindrecord_dir, prefix)
  387. writer = FileWriter(mindrecord_path, file_num)
  388. if dataset == "coco":
  389. image_files, image_anno_dict, masks, masks_shape = create_coco_label(is_training)
  390. else:
  391. print("Error unsupport other dataset")
  392. return
  393. maskrcnn_json = {
  394. "image": {"type": "bytes"},
  395. "annotation": {"type": "int32", "shape": [-1, 6]},
  396. "mask": {"type": "bytes"},
  397. "mask_shape": {"type": "int32", "shape": [-1]},
  398. }
  399. writer.add_schema(maskrcnn_json, "maskrcnn_json")
  400. for image_name in image_files:
  401. with open(image_name, 'rb') as f:
  402. img = f.read()
  403. annos = np.array(image_anno_dict[image_name], dtype=np.int32)
  404. mask = masks[image_name]
  405. mask_shape = masks_shape[image_name]
  406. row = {"image": img, "annotation": annos, "mask": mask, "mask_shape": mask_shape}
  407. writer.write_raw_data([row])
  408. writer.commit()
def create_maskrcnn_dataset(mindrecord_file, batch_size=2, device_num=1, rank_id=0,
                            is_training=True, num_parallel_workers=8):
    """Create MaskRcnn dataset with MindDataset.

    Args:
        mindrecord_file: path/prefix of the MindRecord file(s) to read.
        batch_size: samples per batch; remainder batches are dropped.
        device_num: number of shards for distributed runs.
        rank_id: shard id of this device.
        is_training: enables shuffling and the training preprocess path.
        num_parallel_workers: workers for the preprocess map op.

    Returns:
        Batched dataset yielding columns
        (image, image_shape, box, label, valid_num, mask).
    """
    # keep OpenCV single-threaded so it doesn't contend with dataset workers
    cv2.setNumThreads(0)
    de.config.set_prefetch_size(8)
    ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation", "mask", "mask_shape"],
                        num_shards=device_num, shard_id=rank_id,
                        num_parallel_workers=4, shuffle=is_training)

    decode = C.Decode()
    ds = ds.map(input_columns=["image"], operations=decode)
    # bind is_training into the per-row preprocess call
    compose_map_func = (lambda image, annotation, mask, mask_shape:
                        preprocess_fn(image, annotation, mask, mask_shape, is_training))

    if is_training:
        # python_multiprocessing disabled; workers run the map in threads
        ds = ds.map(input_columns=["image", "annotation", "mask", "mask_shape"],
                    output_columns=["image", "image_shape", "box", "label", "valid_num", "mask"],
                    columns_order=["image", "image_shape", "box", "label", "valid_num", "mask"],
                    operations=compose_map_func,
                    python_multiprocessing=False,
                    num_parallel_workers=num_parallel_workers)

        ds = ds.batch(batch_size, drop_remainder=True)
    else:
        ds = ds.map(input_columns=["image", "annotation", "mask", "mask_shape"],
                    output_columns=["image", "image_shape", "box", "label", "valid_num", "mask"],
                    columns_order=["image", "image_shape", "box", "label", "valid_num", "mask"],
                    operations=compose_map_func,
                    num_parallel_workers=num_parallel_workers)

        ds = ds.batch(batch_size, drop_remainder=True)

    return ds