You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """YOLOv3 dataset"""
  16. from __future__ import division
  17. import abc
  18. import io
  19. import os
  20. import math
  21. import json
  22. import numpy as np
  23. from PIL import Image
  24. from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
  25. import mindspore.dataset as de
  26. import mindspore.dataset.transforms.vision.py_transforms as P
  27. from config import ConfigYOLOV3ResNet18
  iter_cnt = 0
  # Row count of every zero-padded ground-truth box array emitted per detection scale.
  _NUM_BOXES = 50


  def preprocess_fn(image, box, is_training):
      """Turn one decoded image and its annotations into YOLOv3 network inputs.

      Args:
          image: Image as an HWC array (or a ``PIL.Image.Image``).
          box: Array of ``[xmin, ymin, xmax, ymax, class_id]`` rows in source-image
              pixels. It is copied as float32 and never modified in place.
          is_training: When true, apply random scale/aspect jitter, random
              placement on a grey canvas, random horizontal flip and random
              grayscale conversion. When false, only resize to the input size.

      Returns:
          Tuple ``(image, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3)``:
          the float32 image scaled to ``[0, 1]``, one target grid per detection
          scale (strides 32, 16 and 8) and, per scale, a ``(_NUM_BOXES, 4)``
          float32 array holding the boxes assigned on that scale, zero-padded.
      """
      config_anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 163, 326]
      anchors = np.array([float(x) for x in config_anchors]).reshape(-1, 2)
      do_hsv = False
      max_boxes = 20
      num_classes = ConfigYOLOV3ResNet18.num_classes

      def _rand(a=0., b=1.):
          """Draw one uniform sample from [a, b)."""
          return np.random.rand() * (b - a) + a

      def _pad_gt_boxes(layer_target):
          """Gather the boxes written into one target grid, zero-padded to _NUM_BOXES rows."""
          padded = np.zeros(shape=[_NUM_BOXES, 4], dtype=np.float32)
          assigned = np.reshape(layer_target[..., 4:5], [-1]) == 1
          gt_box = np.reshape(layer_target[..., 0:4], [-1, 4])[assigned]
          padded[:gt_box.shape[0]] = gt_box
          return padded

      def _preprocess_true_boxes(true_boxes, anchors, in_shape=None):
          """Encode padded corner-format boxes as per-scale training targets.

          Each valid box (width of at least one pixel) is matched to the anchor
          with the highest IoU (boxes and anchors compared centred at the
          origin) and written into the grid cell containing its centre on the
          scale that owns that anchor.
          """
          num_layers = anchors.shape[0] // 3
          anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
          true_boxes = np.array(true_boxes, dtype='float32')
          input_shape = np.array(in_shape, dtype='int32')
          boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
          boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
          # input_shape is (height, width); boxes are (x, y), hence the reversal.
          true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
          true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]
          grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
          y_true = [np.zeros((grid_shapes[layer][0], grid_shapes[layer][1], len(anchor_mask[layer]),
                              5 + num_classes), dtype='float32') for layer in range(num_layers)]
          anchors = np.expand_dims(anchors, 0)
          anchors_max = anchors / 2.
          anchors_min = -anchors_max
          valid_mask = boxes_wh[..., 0] >= 1
          # Remember which rows of true_boxes survived the filter so the loop
          # below indexes the right box even if an invalid row precedes a valid one.
          valid_rows = np.flatnonzero(valid_mask)
          wh = boxes_wh[valid_mask]
          if len(wh) >= 1:
              wh = np.expand_dims(wh, -2)
              boxes_max = wh / 2.
              boxes_min = -boxes_max
              intersect_min = np.maximum(boxes_min, anchors_min)
              intersect_max = np.minimum(boxes_max, anchors_max)
              intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
              intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
              box_area = wh[..., 0] * wh[..., 1]
              anchor_area = anchors[..., 0] * anchors[..., 1]
              iou = intersect_area / (box_area + anchor_area - intersect_area)
              best_anchor = np.argmax(iou, axis=-1)
              for t, n in zip(valid_rows, best_anchor):
                  for layer in range(num_layers):
                      if n in anchor_mask[layer]:
                          i = np.floor(true_boxes[t, 0] * grid_shapes[layer][1]).astype('int32')
                          j = np.floor(true_boxes[t, 1] * grid_shapes[layer][0]).astype('int32')
                          k = anchor_mask[layer].index(n)
                          c = true_boxes[t, 4].astype('int32')
                          y_true[layer][j, i, k, 0:4] = true_boxes[t, 0:4]
                          y_true[layer][j, i, k, 4] = 1.
                          y_true[layer][j, i, k, 5 + c] = 1.
          return y_true[0], y_true[1], y_true[2], \
              _pad_gt_boxes(y_true[0]), _pad_gt_boxes(y_true[1]), _pad_gt_boxes(y_true[2])

      def _data_aug(image, box, is_training, jitter=0.3, hue=0.1, sat=1.5, val=1.5, image_size=(352, 640)):
          """Resize (evaluation) or randomly augment (training) the image and remap its boxes."""
          if not isinstance(image, Image.Image):
              image = Image.fromarray(image)
          # Private float copy: keeps the caller's array untouched and stops
          # integer annotations from truncating the rescaled coordinates.
          box = np.array(box, dtype=np.float32)
          iw, ih = image.size
          ori_image_shape = np.array([ih, iw], np.int32)
          h, w = image_size
          if not is_training:
              image = image.resize((w, h), Image.BICUBIC)
              image_data = np.array(image) / 255.
              if len(image_data.shape) == 2:
                  # Single-channel image: replicate it into three channels.
                  image_data = np.expand_dims(image_data, axis=-1)
                  image_data = np.concatenate([image_data, image_data, image_data], axis=-1)
              image_data = image_data.astype(np.float32)
              # correct boxes; an image without boxes keeps the all-zero padding
              box_data = np.zeros((max_boxes, 5))
              if len(box) >= 1:
                  np.random.shuffle(box)
                  if len(box) > max_boxes:
                      box = box[:max_boxes]
                  # xmin ymin xmax ymax
                  box[:, [0, 2]] = box[:, [0, 2]] * float(w) / float(iw)
                  box[:, [1, 3]] = box[:, [1, 3]] * float(h) / float(ih)
                  box_data[:len(box)] = box
              # preprocess bounding boxes
              bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
                  _preprocess_true_boxes(box_data, anchors, image_size)
              return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
                  ori_image_shape, gt_box1, gt_box2, gt_box3

          flip = _rand() < .5
          # correct boxes
          box_data = np.zeros((max_boxes, 5))
          while True:
              # Resample the geometry until at least one box survives it.
              new_ar = float(w) / float(h) * _rand(1 - jitter, 1 + jitter) / \
                  _rand(1 - jitter, 1 + jitter)
              scale = _rand(0.25, 2)
              if new_ar < 1:
                  nh = int(scale * h)
                  nw = int(nh * new_ar)
              else:
                  nw = int(scale * w)
                  nh = int(nw / new_ar)
              dx = int(_rand(0, w - nw))
              dy = int(_rand(0, h - nh))
              if len(box) < 1:
                  # No annotations: there is nothing to preserve, accept this geometry.
                  break
              t_box = box.copy()
              np.random.shuffle(t_box)
              t_box[:, [0, 2]] = t_box[:, [0, 2]] * float(nw) / float(iw) + dx
              t_box[:, [1, 3]] = t_box[:, [1, 3]] * float(nh) / float(ih) + dy
              if flip:
                  t_box[:, [0, 2]] = w - t_box[:, [2, 0]]
              t_box[:, 0:2][t_box[:, 0:2] < 0] = 0
              t_box[:, 2][t_box[:, 2] > w] = w
              t_box[:, 3][t_box[:, 3] > h] = h
              box_w = t_box[:, 2] - t_box[:, 0]
              box_h = t_box[:, 3] - t_box[:, 1]
              t_box = t_box[np.logical_and(box_w > 1, box_h > 1)]  # discard invalid box
              if len(t_box) >= 1:
                  # Keep at most max_boxes rows so they fit into box_data.
                  box = t_box[:max_boxes]
                  break
          if len(box) >= 1:
              box_data[:len(box)] = box
          # resize image
          image = image.resize((nw, nh), Image.BICUBIC)
          # place image
          new_image = Image.new('RGB', (w, h), (128, 128, 128))
          new_image.paste(image, (dx, dy))
          image = new_image
          # flip image or not
          if flip:
              image = image.transpose(Image.FLIP_LEFT_RIGHT)
          # convert image to gray or not
          gray = _rand() < .25
          if gray:
              image = image.convert('L').convert('RGB')
          # when the channels of image is 1
          image = np.array(image)
          if len(image.shape) == 2:
              image = np.expand_dims(image, axis=-1)
              image = np.concatenate([image, image, image], axis=-1)
          # distort image (the factors are drawn even when do_hsv is off)
          hue = _rand(-hue, hue)
          sat = _rand(1, sat) if _rand() < .5 else 1 / _rand(1, sat)
          val = _rand(1, val) if _rand() < .5 else 1 / _rand(1, val)
          image_data = image / 255.
          if do_hsv:
              x = rgb_to_hsv(image_data)
              x[..., 0] += hue
              x[..., 0][x[..., 0] > 1] -= 1
              x[..., 0][x[..., 0] < 0] += 1
              x[..., 1] *= sat
              x[..., 2] *= val
              x[x > 1] = 1
              x[x < 0] = 0
              image_data = hsv_to_rgb(x)  # numpy array, 0 to 1
          image_data = image_data.astype(np.float32)
          # preprocess bounding boxes
          bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \
              _preprocess_true_boxes(box_data, anchors, image_size)
          return image_data, bbox_true_1, bbox_true_2, bbox_true_3, \
              ori_image_shape, gt_box1, gt_box2, gt_box3

      images, bbox_1, bbox_2, bbox_3, _, gt_box1, gt_box2, gt_box3 = _data_aug(image, box, is_training)
      return images, bbox_1, bbox_2, bbox_3, gt_box1, gt_box2, gt_box3
  def anno_parser(annos_str):
      """Parse annotation tokens into lists of integers.

      Args:
          annos_str: Iterable of strings, each a comma-separated list of
              integers (for example ``"xmin,ymin,xmax,ymax,class"``).

      Returns:
          One list of ints per non-blank token, in input order.

      Raises:
          ValueError: If a non-blank token has a field that is not an integer.
      """
      annos = []
      for anno_str in annos_str:
          fields = anno_str.strip()
          if not fields:
              # Blank tokens (stray spaces or a bare newline) carry no box.
              continue
          annos.append([int(value) for value in fields.split(',')])
      return annos
  def expand_path(path):
      """List the names of the regular files directly inside a directory.

      Args:
          path: Directory to scan (not recursive).

      Returns:
          File names (not full paths) in the order ``os.listdir`` yields them.

      Raises:
          RuntimeError: If ``path`` is not an existing directory.
      """
      if not os.path.isdir(path):
          raise RuntimeError("Path given is not valid.")
      return [entry for entry in os.listdir(path)
              if os.path.isfile(os.path.join(path, entry))]
  def read_image(img_path):
      """Load an image file and return its pixels as a numpy array.

      The file is read fully into memory and decoded with PIL from that
      in-memory buffer.
      """
      with open(img_path, "rb") as handle:
          raw_bytes = handle.read()
      decoded = Image.open(io.BytesIO(raw_bytes))
      return np.array(decoded)
  class BaseDataset():
      """Indexable sample source for a GeneratorDataset.

      Subclasses fill ``samples`` (image file names) and ``image_anno_dict``
      (file name -> annotations) in ``_load_samples``, which runs at the end of
      construction.
      """

      def __init__(self, image_dir, anno_path):
          self.image_dir = image_dir
          self.anno_path = anno_path
          self.cur_index = 0
          self.samples = []
          self.image_anno_dict = {}
          self._load_samples()

      def __len__(self):
          return len(self.samples)

      def __getitem__(self, item):
          file_name = self.samples[item]
          return self._next_data(file_name, self.image_dir, self.image_anno_dict)

      @staticmethod
      def _next_data(sample, image_dir, image_anno_dict):
          """Load one sample as ``[image array, annotation array]``."""
          image_path = os.path.join(image_dir, sample)
          pixels = np.array(read_image(image_path))
          boxes = np.array(image_anno_dict[sample])
          return [pixels, boxes]

      @abc.abstractmethod
      def _load_samples(self):
          """Populate ``samples`` and ``image_anno_dict``; overridden by subclasses."""
  class YoloDataset(BaseDataset):
      """YoloDataset for GeneratorDataset iterator."""

      def _load_samples(self):
          """Collect the annotated images found in ``image_dir``.

          Raises:
              RuntimeError: If no image file has a matching annotation line.
          """
          image_files_raw = expand_path(self.image_dir)
          self.samples = self._filter_valid_data(self.anno_path, image_files_raw)
          self.dataset_size = len(self.samples)
          if self.dataset_size == 0:
              raise RuntimeError("Valid dataset is none!")

      def _filter_valid_data(self, anno_path, image_files_raw):
          """Match annotation lines to image files and parse their boxes.

          Each annotation line is ``<image path> <box> <box> ...`` separated by
          whitespace; only the last ``/`` component of the path is used as key.
          Lines that are blank or carry no box token are skipped.

          Args:
              anno_path: Path of the UTF-8 annotation file.
              image_files_raw: File names present in the image directory.

          Returns:
              Sorted list of file names that have both an image and annotations.
              Their parsed annotations are stored in ``self.image_anno_dict``.
          """
          anno_dict = {}
          print("Start filter valid data.")
          with open(anno_path, "rb") as f:
              for line in f:
                  # Whitespace split tolerates trailing newlines (including
                  # "\r\n"), tabs and repeated spaces, none of which may end up
                  # in the file-name key or as an empty annotation token.
                  fields = line.decode("utf-8").split()
                  if len(fields) < 2:
                      continue
                  anno_dict[fields[0].split("/")[-1]] = fields[1:]
          image_files = sorted(set(anno_dict) & set(image_files_raw))
          for image_file in image_files:
              self.image_anno_dict[image_file] = anno_parser(anno_dict[image_file])
          print("Filter valid data done!")
          return image_files
  class DistributedSampler():
      """Index sampler that splits a dataset across replicas for YOLOv3.

      Every replica receives ``num_samples`` indices per epoch, where
      ``num_samples`` is the larger of ``batch_size`` and
      ``ceil(dataset_size / num_replicas)``. The index list is repeated
      cyclically until it holds ``num_samples * num_replicas`` entries, then
      replica ``rank`` takes every ``num_replicas``-th entry starting at
      ``rank``.

      Args:
          dataset_size: Number of samples in the dataset.
          batch_size: Lower bound for the number of samples per replica.
          num_replicas: Number of replicas; ``None`` means 1.
          rank: Index of this replica (taken modulo ``num_replicas``); ``None``
              means 0.
          shuffle: Shuffle the indices with a permutation seeded by the epoch.
      """

      def __init__(self, dataset_size, batch_size, num_replicas=None, rank=None, shuffle=True):
          if num_replicas is None:
              num_replicas = 1
          if rank is None:
              rank = 0
          self.dataset_size = dataset_size
          self.num_replicas = num_replicas
          self.rank = rank % num_replicas
          self.epoch = 0
          self.num_samples = max(batch_size, int(math.ceil(dataset_size * 1.0 / self.num_replicas)))
          self.total_size = self.num_samples * self.num_replicas
          self.shuffle = shuffle

      def __iter__(self):
          """Yield this replica's indices for the current epoch.

          Raises:
              ValueError: If indices are required but the dataset is empty.
          """
          # deterministically shuffle based on epoch
          if self.shuffle:
              indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size).tolist()
          else:
              indices = list(range(self.dataset_size))
          if len(indices) < self.total_size:
              if not indices:
                  raise ValueError("Cannot sample from an empty dataset.")
              # Repeat cyclically: a single extra copy is not enough when the
              # dataset is much smaller than batch_size * num_replicas.
              repeats = -(-self.total_size // len(indices))
              indices = (indices * repeats)[:self.total_size]
          # subsample: every num_replicas-th index, offset by this replica's rank
          return iter(indices[self.rank:self.total_size:self.num_replicas])

      def __len__(self):
          return self.num_samples

      def set_epoch(self, epoch):
          """Set the epoch number used as the shuffle seed."""
          self.epoch = epoch
  def create_yolo_dataset(image_dir, anno_path, batch_size=32, repeat_num=10, device_num=1, rank=0,
                          is_training=True, num_parallel_workers=8):
      """Create YOLOv3 dataset with GeneratorDataset.

      Pipeline: sharded sampling of ``YoloDataset`` -> ``preprocess_fn`` ->
      HWC-to-CHW conversion -> shuffle -> batch (remainder dropped) -> repeat.
      """
      source = YoloDataset(image_dir=image_dir, anno_path=anno_path)
      sampler = DistributedSampler(source.dataset_size, batch_size, device_num, rank)
      ds = de.GeneratorDataset(source, column_names=["image", "annotation"], sampler=sampler)
      ds.set_dataset_size(len(sampler))

      def _to_network_inputs(image, annotation):
          # Bind is_training so the map operation only receives the two columns.
          return preprocess_fn(image, annotation, is_training)

      channel_first = P.HWC2CHW()
      network_columns = ["image", "bbox_1", "bbox_2", "bbox_3", "gt_box1", "gt_box2", "gt_box3"]
      ds = ds.map(input_columns=["image", "annotation"],
                  output_columns=list(network_columns),
                  columns_order=list(network_columns),
                  operations=_to_network_inputs, num_parallel_workers=num_parallel_workers)
      ds = ds.map(input_columns=["image"], operations=channel_first, num_parallel_workers=num_parallel_workers)
      ds = ds.shuffle(buffer_size=256)
      ds = ds.batch(batch_size, drop_remainder=True)
      ds = ds.repeat(repeat_num)
      return ds

MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.

Contributors (1)