You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """CTPN dataset"""
  16. from __future__ import division
  17. import numpy as np
  18. from numpy import random
  19. import mmcv
  20. import mindspore.dataset as de
  21. import mindspore.dataset.vision.c_transforms as C
  22. import mindspore.dataset.transforms.c_transforms as CC
  23. import mindspore.common.dtype as mstype
  24. from src.config import config
  25. class PhotoMetricDistortion:
  26. """Photo Metric Distortion"""
  27. def __init__(self,
  28. brightness_delta=32,
  29. contrast_range=(0.5, 1.5),
  30. saturation_range=(0.5, 1.5),
  31. hue_delta=18):
  32. self.brightness_delta = brightness_delta
  33. self.contrast_lower, self.contrast_upper = contrast_range
  34. self.saturation_lower, self.saturation_upper = saturation_range
  35. self.hue_delta = hue_delta
  36. def __call__(self, img, boxes, labels):
  37. img = img.astype('float32')
  38. if random.randint(2):
  39. delta = random.uniform(-self.brightness_delta, self.brightness_delta)
  40. img += delta
  41. mode = random.randint(2)
  42. if mode == 1:
  43. if random.randint(2):
  44. alpha = random.uniform(self.contrast_lower,
  45. self.contrast_upper)
  46. img *= alpha
  47. # convert color from BGR to HSV
  48. img = mmcv.bgr2hsv(img)
  49. # random saturation
  50. if random.randint(2):
  51. img[..., 1] *= random.uniform(self.saturation_lower,
  52. self.saturation_upper)
  53. # random hue
  54. if random.randint(2):
  55. img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
  56. img[..., 0][img[..., 0] > 360] -= 360
  57. img[..., 0][img[..., 0] < 0] += 360
  58. # convert color from HSV to BGR
  59. img = mmcv.hsv2bgr(img)
  60. # random contrast
  61. if mode == 0:
  62. if random.randint(2):
  63. alpha = random.uniform(self.contrast_lower,
  64. self.contrast_upper)
  65. img *= alpha
  66. # randomly swap channels
  67. if random.randint(2):
  68. img = img[..., random.permutation(3)]
  69. return img, boxes, labels
  70. class Expand:
  71. """expand image"""
  72. def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
  73. if to_rgb:
  74. self.mean = mean[::-1]
  75. else:
  76. self.mean = mean
  77. self.min_ratio, self.max_ratio = ratio_range
  78. def __call__(self, img, boxes, labels):
  79. if random.randint(2):
  80. return img, boxes, labels
  81. h, w, c = img.shape
  82. ratio = random.uniform(self.min_ratio, self.max_ratio)
  83. expand_img = np.full((int(h * ratio), int(w * ratio), c),
  84. self.mean).astype(img.dtype)
  85. left = int(random.uniform(0, w * ratio - w))
  86. top = int(random.uniform(0, h * ratio - h))
  87. expand_img[top:top + h, left:left + w] = img
  88. img = expand_img
  89. boxes += np.tile((left, top), 2)
  90. return img, boxes, labels
  91. def rescale_column(img, gt_bboxes, gt_label, gt_num, img_shape):
  92. """rescale operation for image"""
  93. img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True)
  94. if img_data.shape[0] > config.img_height:
  95. img_data, scale_factor2 = mmcv.imrescale(img_data, (config.img_height, config.img_width), return_scale=True)
  96. scale_factor = scale_factor * scale_factor2
  97. img_shape = np.append(img_shape, scale_factor)
  98. img_shape = np.asarray(img_shape, dtype=np.float32)
  99. gt_bboxes = gt_bboxes * scale_factor
  100. gt_bboxes = split_gtbox_label(gt_bboxes)
  101. if gt_bboxes.shape[0] != 0:
  102. gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
  103. gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
  104. return (img_data, gt_bboxes, gt_label, gt_num, img_shape)
  105. def resize_column(img, gt_bboxes, gt_label, gt_num, img_shape):
  106. """resize operation for image"""
  107. img_data = img
  108. img_data, w_scale, h_scale = mmcv.imresize(
  109. img_data, (config.img_width, config.img_height), return_scale=True)
  110. scale_factor = np.array(
  111. [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
  112. img_shape = (config.img_height, config.img_width, 1.0)
  113. img_shape = np.asarray(img_shape, dtype=np.float32)
  114. gt_bboxes = gt_bboxes * scale_factor
  115. gt_bboxes = split_gtbox_label(gt_bboxes)
  116. gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
  117. gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
  118. return (img_data, gt_bboxes, gt_label, gt_num, img_shape)
  119. def resize_column_test(img, gt_bboxes, gt_label, gt_num, img_shape):
  120. """resize operation for image of eval"""
  121. img_data = img
  122. img_data, w_scale, h_scale = mmcv.imresize(
  123. img_data, (config.img_width, config.img_height), return_scale=True)
  124. scale_factor = np.array(
  125. [w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
  126. img_shape = (config.img_height, config.img_width)
  127. img_shape = np.append(img_shape, (h_scale, w_scale))
  128. img_shape = np.asarray(img_shape, dtype=np.float32)
  129. gt_bboxes = gt_bboxes * scale_factor
  130. shape = gt_bboxes.shape
  131. label_column = np.ones((shape[0], 1), dtype=int)
  132. gt_bboxes = np.concatenate((gt_bboxes, label_column), axis=1)
  133. gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1)
  134. gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1)
  135. return (img_data, gt_bboxes, gt_label, gt_num, img_shape)
  136. def flipped_generation(img, gt_bboxes, gt_label, gt_num, img_shape):
  137. """flipped generation"""
  138. img_data = img
  139. flipped = gt_bboxes.copy()
  140. _, w, _ = img_data.shape
  141. flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
  142. flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1
  143. return (img_data, flipped, gt_label, gt_num, img_shape)
  144. def image_bgr_rgb(img, gt_bboxes, gt_label, gt_num, img_shape):
  145. img_data = img[:, :, ::-1]
  146. return (img_data, gt_bboxes, gt_label, gt_num, img_shape)
  147. def photo_crop_column(img, gt_bboxes, gt_label, gt_num, img_shape):
  148. """photo crop operation for image"""
  149. random_photo = PhotoMetricDistortion()
  150. img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label)
  151. return (img_data, gt_bboxes, gt_label, gt_num, img_shape)
  152. def expand_column(img, gt_bboxes, gt_label, gt_num, img_shape):
  153. """expand operation for image"""
  154. expand = Expand()
  155. img, gt_bboxes, gt_label = expand(img, gt_bboxes, gt_label)
  156. return (img, gt_bboxes, gt_label, gt_num, img_shape)
  157. def split_gtbox_label(gt_bbox_total):
  158. """split ground truth box label"""
  159. gtbox_list = []
  160. box_num, _ = gt_bbox_total.shape
  161. for i in range(box_num):
  162. gt_bbox = gt_bbox_total[i]
  163. if gt_bbox[0] % 16 != 0:
  164. gt_bbox[0] = (gt_bbox[0] // 16) * 16
  165. if gt_bbox[2] % 16 != 0:
  166. gt_bbox[2] = (gt_bbox[2] // 16 + 1) * 16
  167. x0_array = np.arange(gt_bbox[0], gt_bbox[2], 16)
  168. for x0 in x0_array:
  169. gtbox_list.append([x0, gt_bbox[1], x0+15, gt_bbox[3], 1])
  170. return np.array(gtbox_list)
  171. def pad_label(img, gt_bboxes, gt_label, gt_valid, img_shape):
  172. """pad ground truth label"""
  173. pad_max_number = 256
  174. gt_label = gt_bboxes[:, 4]
  175. gt_valid = gt_bboxes[:, 4]
  176. if gt_bboxes.shape[0] < 256:
  177. gt_box = np.pad(gt_bboxes, ((0, pad_max_number - gt_bboxes.shape[0]), (0, 0)), \
  178. mode="constant", constant_values=0)
  179. gt_label = np.pad(gt_label, ((0, pad_max_number - gt_bboxes.shape[0])), mode="constant", constant_values=-1)
  180. gt_valid = np.pad(gt_valid, ((0, pad_max_number - gt_bboxes.shape[0])), mode="constant", constant_values=0)
  181. else:
  182. print("WARNING label num is high than 256")
  183. gt_box = gt_bboxes[0:pad_max_number]
  184. gt_label = gt_label[0:pad_max_number]
  185. gt_valid = gt_valid[0:pad_max_number]
  186. return (img, gt_box[:, :4], gt_label, gt_valid, img_shape)
  187. def preprocess_fn(image, box, is_training):
  188. """Preprocess function for dataset."""
  189. def _infer_data(image_bgr, gt_box_new, gt_label_new, gt_valid, image_shape):
  190. image_shape = image_shape[:2]
  191. input_data = image_bgr, gt_box_new, gt_label_new, gt_valid, image_shape
  192. if config.keep_ratio:
  193. input_data = rescale_column(*input_data)
  194. else:
  195. input_data = resize_column_test(*input_data)
  196. input_data = pad_label(*input_data)
  197. input_data = image_bgr_rgb(*input_data)
  198. output_data = input_data
  199. return output_data
  200. def _data_aug(image, box, is_training):
  201. """Data augmentation function."""
  202. image_bgr = image.copy()
  203. image_bgr[:, :, 0] = image[:, :, 2]
  204. image_bgr[:, :, 1] = image[:, :, 1]
  205. image_bgr[:, :, 2] = image[:, :, 0]
  206. image_shape = image_bgr.shape[:2]
  207. gt_box = box[:, :4]
  208. gt_label = box[:, 4]
  209. gt_valid = box[:, 4]
  210. input_data = image_bgr, gt_box, gt_label, gt_valid, image_shape
  211. if not is_training:
  212. return _infer_data(image_bgr, gt_box, gt_label, gt_valid, image_shape)
  213. expand = (np.random.rand() < config.expand_ratio)
  214. if expand:
  215. input_data = expand_column(*input_data)
  216. input_data = photo_crop_column(*input_data)
  217. if config.keep_ratio:
  218. input_data = rescale_column(*input_data)
  219. else:
  220. input_data = resize_column(*input_data)
  221. input_data = pad_label(*input_data)
  222. input_data = image_bgr_rgb(*input_data)
  223. output_data = input_data
  224. return output_data
  225. return _data_aug(image, box, is_training)
  226. def anno_parser(annos_str):
  227. """Parse annotation from string to list."""
  228. annos = []
  229. for anno_str in annos_str:
  230. anno = list(map(int, anno_str.strip().split(',')))
  231. annos.append(anno)
  232. return annos
  233. def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=1, rank_id=0,
  234. is_training=True, num_parallel_workers=12):
  235. """Creatr ctpn dataset with MindDataset."""
  236. ds = de.MindDataset(mindrecord_file, columns_list=["image", "annotation"], num_shards=device_num, shard_id=rank_id,\
  237. num_parallel_workers=num_parallel_workers, shuffle=is_training)
  238. decode = C.Decode()
  239. ds = ds.map(operations=decode, input_columns=["image"], num_parallel_workers=num_parallel_workers)
  240. compose_map_func = (lambda image, annotation: preprocess_fn(image, annotation, is_training))
  241. hwc_to_chw = C.HWC2CHW()
  242. normalize_op = C.Normalize((123.675, 116.28, 103.53), (58.395, 57.12, 57.375))
  243. type_cast0 = CC.TypeCast(mstype.float32)
  244. type_cast1 = CC.TypeCast(mstype.float16)
  245. type_cast2 = CC.TypeCast(mstype.int32)
  246. type_cast3 = CC.TypeCast(mstype.bool_)
  247. if is_training:
  248. ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"],
  249. output_columns=["image", "box", "label", "valid_num", "image_shape"],
  250. column_order=["image", "box", "label", "valid_num", "image_shape"],
  251. num_parallel_workers=num_parallel_workers,
  252. python_multiprocessing=True)
  253. ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"],
  254. num_parallel_workers=num_parallel_workers,
  255. python_multiprocessing=True)
  256. ds = ds.map(operations=[hwc_to_chw, type_cast1], input_columns=["image"],
  257. num_parallel_workers=num_parallel_workers,
  258. python_multiprocessing=True)
  259. else:
  260. ds = ds.map(operations=compose_map_func,
  261. input_columns=["image", "annotation"],
  262. output_columns=["image", "box", "label", "valid_num", "image_shape"],
  263. column_order=["image", "box", "label", "valid_num", "image_shape"],
  264. num_parallel_workers=num_parallel_workers,
  265. python_multiprocessing=True)
  266. ds = ds.map(operations=[normalize_op, hwc_to_chw, type_cast1], input_columns=["image"],
  267. num_parallel_workers=24)
  268. # transpose_column from python to c
  269. ds = ds.map(operations=[type_cast1], input_columns=["image_shape"])
  270. ds = ds.map(operations=[type_cast1], input_columns=["box"])
  271. ds = ds.map(operations=[type_cast2], input_columns=["label"])
  272. ds = ds.map(operations=[type_cast3], input_columns=["valid_num"])
  273. ds = ds.batch(batch_size, drop_remainder=True)
  274. ds = ds.repeat(repeat_num)
  275. return ds