You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """YOLOv3 dataset"""
  16. from __future__ import division
  17. import os
  18. import numpy as np
  19. from PIL import Image
  20. from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
  21. import mindspore.dataset as de
  22. from mindspore.mindrecord import FileWriter
  23. import mindspore.dataset.transforms.vision.py_transforms as P
  24. import mindspore.dataset.transforms.vision.c_transforms as C
  25. from config import ConfigYOLOV3ResNet18
  26. iter_cnt = 0
  27. _NUM_BOXES = 50
  28. def preprocess_fn(image, box, is_training):
  29. """Preprocess function for dataset."""
  30. config_anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 163, 326]
  31. anchors = np.array([float(x) for x in config_anchors]).reshape(-1, 2)
  32. do_hsv = False
  33. max_boxes = 20
  34. num_classes = ConfigYOLOV3ResNet18.num_classes
  35. def _rand(a=0., b=1.):
  36. return np.random.rand() * (b - a) + a
  37. def _preprocess_true_boxes(true_boxes, anchors, in_shape=None):
  38. """Get true boxes."""
  39. num_layers = anchors.shape[0] // 3
  40. anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  41. true_boxes = np.array(true_boxes, dtype='float32')
  42. # input_shape = np.array([in_shape, in_shape], dtype='int32')
  43. input_shape = np.array(in_shape, dtype='int32')
  44. boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
  45. boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
  46. true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
  47. true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]
  48. grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
  49. y_true = [np.zeros((grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]),
  50. 5 + num_classes), dtype='float32') for l in range(num_layers)]
  51. anchors = np.expand_dims(anchors, 0)
  52. anchors_max = anchors / 2.
  53. anchors_min = -anchors_max
  54. valid_mask = boxes_wh[..., 0] >= 1
  55. wh = boxes_wh[valid_mask]
  56. if len(wh) >= 1:
  57. wh = np.expand_dims(wh, -2)
  58. boxes_max = wh / 2.
  59. boxes_min = -boxes_max
  60. intersect_min = np.maximum(boxes_min, anchors_min)
  61. intersect_max = np.minimum(boxes_max, anchors_max)
  62. intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
  63. intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
  64. box_area = wh[..., 0] * wh[..., 1]
  65. anchor_area = anchors[..., 0] * anchors[..., 1]
  66. iou = intersect_area / (box_area + anchor_area - intersect_area)
  67. best_anchor = np.argmax(iou, axis=-1)
  68. for t, n in enumerate(best_anchor):
  69. for l in range(num_layers):
  70. if n in anchor_mask[l]:
  71. i = np.floor(true_boxes[t, 0] * grid_shapes[l][1]).astype('int32')
  72. j = np.floor(true_boxes[t, 1] * grid_shapes[l][0]).astype('int32')
  73. k = anchor_mask[l].index(n)
  74. c = true_boxes[t, 4].astype('int32')
  75. y_true[l][j, i, k, 0:4] = true_boxes[t, 0:4]
  76. y_true[l][j, i, k, 4] = 1.
  77. y_true[l][j, i, k, 5 + c] = 1.
  78. pad_gt_box0 = np.zeros(shape=[50, 4], dtype=np.float32)
  79. pad_gt_box1 = np.zeros(shape=[50, 4], dtype=np.float32)
  80. pad_gt_box2 = np.zeros(shape=[50, 4], dtype=np.float32)
  81. mask0 = np.reshape(y_true[0][..., 4:5], [-1])
  82. gt_box0 = np.reshape(y_true[0][..., 0:4], [-1, 4])
  83. gt_box0 = gt_box0[mask0 == 1]
  84. pad_gt_box0[:gt_box0.shape[0]] = gt_box0
  85. mask1 = np.reshape(y_true[1][..., 4:5], [-1])
  86. gt_box1 = np.reshape(y_true[1][..., 0:4], [-1, 4])
  87. gt_box1 = gt_box1[mask1 == 1]
  88. pad_gt_box1[:gt_box1.shape[0]] = gt_box1
  89. mask2 = np.reshape(y_true[2][..., 4:5], [-1])
  90. gt_box2 = np.reshape(y_true[2][..., 0:4], [-1, 4])
  91. gt_box2 = gt_box2[mask2 == 1]
  92. pad_gt_box2[:gt_box2.shape[0]] = gt_box2
  93. return y_true[0], y_true[1], y_true[2], pad_gt_box0, pad_gt_box1, pad_gt_box2
  94. def _infer_data(img_data, input_shape, box):
  95. w, h = img_data.size
  96. input_h, input_w = input_shape
  97. scale = min(float(input_w) / float(w), float(input_h) / float(h))
  98. nw = int(w * scale)
  99. nh = int(h * scale)
  100. img_data = img_data.resize((nw, nh), Image.BICUBIC)
  101. new_image = np.zeros((input_h, input_w, 3), np.float32)
  102. new_image.fill(128)
  103. img_data = np.array(img_data)
  104. if len(img_data.shape) == 2:
  105. img_data = np.expand_dims(img_data, axis=-1)
  106. img_data = np.concatenate([img_data, img_data, img_data], axis=-1)
  107. dh = int((input_h - nh) / 2)
  108. dw = int((input_w - nw) / 2)
  109. new_image[dh:(nh + dh), dw:(nw + dw), :] = img_data
  110. new_image /= 255.
  111. new_image = np.transpose(new_image, (2, 0, 1))
  112. new_image = np.expand_dims(new_image, 0)
  113. return new_image, np.array([h, w], np.float32), box
  114. def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)):
  115. """Data augmentation function."""
  116. if not isinstance(image, Image.Image):
  117. image = Image.fromarray(image)
  118. iw, ih = image.size
  119. ori_image_shape = np.array([ih, iw], np.int32)
  120. h, w = image_size
  121. if not is_training:
  122. return _infer_data(image, image_size, box)
  123. flip = _rand() < .5
  124. # correct boxes
  125. box_data = np.zeros((max_boxes, 5))
  126. while True:
  127. # Prevent the situation that all boxes are eliminated
  128. new_ar = float(w) / float(h) * _rand(1 - jitter, 1 + jitter) / \
  129. _rand(1 - jitter, 1 + jitter)
  130. scale = _rand(0.25, 2)
  131. if new_ar < 1:
  132. nh = int(scale * h)
  133. nw = int(nh * new_ar)
  134. else:
  135. nw = int(scale * w)
  136. nh = int(nw / new_ar)
  137. dx = int(_rand(0, w - nw))
  138. dy = int(_rand(0, h - nh))
  139. if len(box) >= 1:
  140. t_box = box.copy()
  141. np.random.shuffle(t_box)
  142. t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(iw) + dx
  143. t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(ih) + dy
  144. if flip:
  145. t_box[:, [0, 2]] = w - t_box[:, [2, 0]]
  146. t_box[:, 0:2][t_box[:, 0:2] < 0] = 0
  147. t_box[:, 2][t_box[:, 2] > w] = w
  148. t_box[:, 3][t_box[:, 3] > h] = h
  149. box_w = t_box[:, 2] - t_box[:, 0]
  150. box_h = t_box[:, 3] - t_box[:, 1]
  151. t_box = t_box[np.logical_and(box_w > 1, box_h > 1)] # discard invalid box
  152. if len(t_box) >= 1:
  153. box = t_box
  154. break
  155. box_data[:len(box)] = box
  156. # resize image
  157. image = image.resize((nw, nh), Image.BICUBIC)
  158. # place image
  159. new_image = Image.new('RGB', (w, h), (128, 128, 128))
  160. new_image.paste(image, (dx, dy))
  161. image = new_image
  162. # flip image or not
  163. if flip:
  164. image = image.transpose(Image.FLIP_LEFT_RIGHT)
  165. # convert image to gray or not
  166. gray = _rand() < .25
  167. if gray:
  168. image = image.convert('L').convert('RGB')
  169. # when the channels of image is 1
  170. image = np.array(image)
  171. if len(image.shape) == 2:
  172. image = np.expand_dims(image, axis=-1)
  173. image = np.concatenate([image, image, image], axis=-1)
  174. # distort image
  175. hue = _rand(-hue, hue)
  176. sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat)
  177. val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val)
  178. image_data = image / 255.
  179. if do_hsv:
  180. x = rgb_to_hsv(image_data)
  181. x[..., 0] += hue
  182. x[..., 0][x[..., 0] > 1] -= 1
  183. x[..., 0][x[..., 0] < 0] += 1
  184. x[..., 1] *= sat
  185. x[..., 2] *= val
  186. x[x > 1] = 1
  187. x[x < 0] = 0
  188. image_data = hsv_to_rgb(x) # numpy array, 0 to 1
  189. image_data = image_data.astype(np.float32)
  190. # preprocess bounding boxes
  191. bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
  192. _preprocess_true_boxes(box_data, anchors, image_size)
  193. return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
  194. ori_image_shape, gt_box1, gt_box2, gt_box3
  195. if is_training:
  196. images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
  197. return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3
  198. images, shape, anno = _data_aug(image, box, is_training)
  199. return images, shape, anno
  200. def anno_parser(annos_str):
  201. """Parse annotation from string to list."""
  202. annos = []
  203. for anno_str in annos_str:
  204. anno = list(map(int, anno_str.strip().split(',')))
  205. annos.append(anno)
  206. return annos
  207. def filter_valid_data(image_dir, anno_path):
  208. """Filter valid image file, which both in image_dir and anno_path."""
  209. image_files = []
  210. image_anno_dict = {}
  211. if not os.path.isdir(image_dir):
  212. raise RuntimeError("Path given is not valid.")
  213. if not os.path.isfile(anno_path):
  214. raise RuntimeError("Annotation file is not valid.")
  215. with open(anno_path, "rb") as f:
  216. lines = f.readlines()
  217. for line in lines:
  218. line_str = line.decode("utf-8").strip()
  219. line_split = str(line_str).split(' ')
  220. file_name = line_split[0]
  221. if os.path.isfile(os.path.join(image_dir, file_name)):
  222. image_anno_dict[file_name] = anno_parser(line_split[1:])
  223. image_files.append(file_name)
  224. return image_files, image_anno_dict
  225. def data_to_mindrecord_byte_image(image_dir, anno_path, mindrecord_dir, prefix="yolo.mindrecord", file_num=8):
  226. """Create MindRecord file by image_dir and anno_path."""
  227. mindrecord_path = os.path.join(mindrecord_dir, prefix)
  228. writer = FileWriter(mindrecord_path, file_num)
  229. image_files, image_anno_dict = filter_valid_data(image_dir, anno_path)
  230. yolo_json = {
  231. "image": {"type": "bytes"},
  232. "annotation": {"type": "int64", "shape": [-1, 5]},
  233. }
  234. writer.add_schema(yolo_json, "yolo_json")
  235. for image_name in image_files:
  236. image_path = os.path.join(image_dir, image_name)
  237. with open(image_path, 'rb') as f:
  238. img = f.read()
  239. annos = np.array(image_anno_dict[image_name])
  240. row = {"image": img, "annotation": annos}
  241. writer.write_raw_data([row])
  242. writer.commit()
  243. def create_yolo_dataset(mindrecord_dir, batch_size=32, repeat_num=10, device_num=1, rank=0,
  244. is_training=True, num_parallel_workers=8):
  245. """Creatr YOLOv3 dataset with MindDataset."""
  246. ds = de.MindDataset(mindrecord_dir, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank,
  247. num_parallel_workers=num_parallel_workers, shuffle=is_training)
  248. decode = C.Decode()
  249. ds = ds.map(input_columns=["image"], operations=decode)
  250. compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
  251. if is_training:
  252. hwc_to_chw = P.HWC2CHW()
  253. ds = ds.map(input_columns=["image", "annotation"],
  254. output_columns=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
  255. columns_order=["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"],
  256. operations=compose_map_func, num_parallel_workers=num_parallel_workers)
  257. ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=num_parallel_workers)
  258. ds = ds.shuffle(buffer_size=256)
  259. ds = ds.batch(batch_size, drop_remainder=True)
  260. ds = ds.repeat(repeat_num)
  261. else:
  262. ds = ds.map(input_columns=["image", "annotation"],
  263. output_columns=["image", "image_shape", "annotation"],
  264. columns_order=["image", "image_shape", "annotation"],
  265. operations=compose_map_func, num_parallel_workers=num_parallel_workers)
  266. return ds