
dataset.py 12 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import os
import random

import Polygon as plg
import cv2
import numpy as np
import pyclipper
from PIL import Image

from src.config import config

import mindspore.dataset.engine as de
import mindspore.dataset.vision.py_transforms as py_transforms

__all__ = ['train_dataset_creator', 'test_dataset_creator']


def get_img(img_path):
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img


def get_imgs_names(root_dir):
    img_paths = [i for i in os.listdir(root_dir)
                 if os.path.splitext(i)[-1].lower() in ['.jpg', '.jpeg', '.png']]
    return img_paths


def get_bboxes(img, gt_path):
    # Parse an ICDAR2015 ground-truth file: each line is
    # "x1,y1,x2,y2,x3,y3,x4,y4,transcription", and a transcription starting
    # with '#' ("###") marks an ignored region. Box coordinates are
    # normalized by the image width and height.
    h, w = img.shape[0:2]
    with open(gt_path, 'r', encoding='utf-8-sig') as f:
        lines = f.readlines()
    bboxes = []
    tags = []
    for line in lines:
        line = line.replace('\xef\xbb\xbf', '')
        line = line.replace('\ufeff', '')
        line = line.replace('\n', '')
        gt = line.split(",", 8)
        tag = gt[-1][0] != '#'
        box = [int(gt[i]) for i in range(8)]
        box = np.asarray(box) / ([w * 1.0, h * 1.0] * 4)
        bboxes.append(box)
        tags.append(tag)
    return np.array(bboxes), tags


def random_scale(img, min_size):
    h, w = img.shape[0:2]
    if max(h, w) > 1280:
        scale1 = 1280.0 / max(h, w)
        img = cv2.resize(img, dsize=None, fx=scale1, fy=scale1)

    h, w = img.shape[0:2]
    random_scale1 = np.array([0.5, 1.0, 2.0, 3.0])
    scale2 = np.random.choice(random_scale1)
    if min(h, w) * scale2 <= min_size:
        scale3 = (min_size + 10) * 1.0 / min(h, w)
        img = cv2.resize(img, dsize=None, fx=scale3, fy=scale3)
    else:
        img = cv2.resize(img, dsize=None, fx=scale2, fy=scale2)
    return img


def random_horizontal_flip(imgs):
    if random.random() < 0.5:
        for i, _ in enumerate(imgs):
            imgs[i] = np.flip(imgs[i], axis=1).copy()
    return imgs


def random_rotate(imgs):
    max_angle = 10
    angle = random.random() * 2 * max_angle - max_angle
    for i, _ in enumerate(imgs):
        img = imgs[i]
        w, h = img.shape[:2]
        rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1)
        img_rotation = cv2.warpAffine(img, rotation_matrix, (h, w))
        imgs[i] = img_rotation
    return imgs


def random_crop(imgs, img_size):
    h, w = imgs[0].shape[0:2]
    th, tw = img_size
    if w == tw and h == th:
        return imgs

    if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0:
        tl = np.min(np.where(imgs[1] > 0), axis=1) - img_size
        tl[tl < 0] = 0
        br = np.max(np.where(imgs[1] > 0), axis=1) - img_size
        br[br < 0] = 0
        br[0] = min(br[0], h - th)
        br[1] = min(br[1], w - tw)

        i = random.randint(tl[0], br[0])
        j = random.randint(tl[1], br[1])
    else:
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)

    for idx, _ in enumerate(imgs):
        if len(imgs[idx].shape) == 3:
            imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
        else:
            imgs[idx] = imgs[idx][i:i + th, j:j + tw]
    return imgs


def scale(img, long_size=2240):
    h, w = img.shape[0:2]
    scale_long = long_size * 1.0 / max(h, w)
    img = cv2.resize(img, dsize=None, fx=scale_long, fy=scale_long)
    return img


def dist(a, b):
    return np.sqrt(np.sum((a - b) ** 2))


def perimeter(bbox):
    peri = 0.0
    for i in range(bbox.shape[0]):
        peri += dist(bbox[i], bbox[(i + 1) % bbox.shape[0]])
    return peri


def shrink(bboxes, rate, max_shr=20):
    # Shrink every text polygon inward with pyclipper to build the PSE
    # kernels: the offset is computed from the polygon area and perimeter
    # and capped at max_shr pixels. If clipping fails or degenerates, the
    # original box is kept.
    rate = rate * rate
    shrinked_bboxes = []
    for bbox in bboxes:
        area = plg.Polygon(bbox).area()
        peri = perimeter(bbox)

        pco = pyclipper.PyclipperOffset()
        pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        offset = min(int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)

        shrinked_bbox = pco.Execute(-offset)
        if not shrinked_bbox:
            shrinked_bboxes.append(bbox)
            continue

        shrinked_bbox = np.array(shrinked_bbox)[0]
        if shrinked_bbox.shape[0] <= 2:
            shrinked_bboxes.append(bbox)
            continue

        shrinked_bboxes.append(shrinked_bbox)
    return np.array(shrinked_bboxes)


class TrainDataset:
    """ICDAR2015 training set yielding (img, gt_text, gt_kernels, training_mask)."""

    def __init__(self):
        self.is_transform = config.TRAIN_IS_TRANSFORM
        self.img_size = config.TRAIN_LONG_SIZE
        self.kernel_num = config.KERNEL_NUM
        self.min_scale = config.TRAIN_MIN_SCALE

        root_dir = os.path.join(os.path.join(os.path.dirname(__file__), '..'), config.TRAIN_ROOT_DIR)
        ic15_train_data_dir = root_dir + 'ch4_training_images/'
        ic15_train_gt_dir = root_dir + 'ch4_training_localization_transcription_gt/'

        self.img_size = self.img_size if \
            (self.img_size is None or isinstance(self.img_size, tuple)) \
            else (self.img_size, self.img_size)

        data_dirs = [ic15_train_data_dir]
        gt_dirs = [ic15_train_gt_dir]

        self.all_img_paths = []
        self.all_gt_paths = []
        for data_dir, gt_dir in zip(data_dirs, gt_dirs):
            img_names = [i for i in os.listdir(data_dir)
                         if os.path.splitext(i)[-1].lower()
                         in ['.jpg', '.jpeg', '.png']]

            img_paths = []
            gt_paths = []
            for _, img_name in enumerate(img_names):
                img_path = os.path.join(data_dir, img_name)
                gt_name = 'gt_' + img_name.split('.')[0] + '.txt'
                gt_path = os.path.join(gt_dir, gt_name)
                img_paths.append(img_path)
                gt_paths.append(gt_path)

            self.all_img_paths.extend(img_paths)
            self.all_gt_paths.extend(gt_paths)

    def __getitem__(self, index):
        img_path = self.all_img_paths[index]
        gt_path = self.all_gt_paths[index]

        img = get_img(img_path)
        bboxes, tags = get_bboxes(img, gt_path)

        # multi-scale training
        if self.is_transform:
            img = random_scale(img, min_size=self.img_size[0])

        # get gt_text and training_mask
        img_h, img_w = img.shape[0: 2]
        gt_text = np.zeros((img_h, img_w), dtype=np.float32)
        training_mask = np.ones((img_h, img_w), dtype=np.float32)
        if bboxes.shape[0] > 0:
            bboxes = np.reshape(bboxes * ([img_w, img_h] * 4), (bboxes.shape[0], -1, 2)).astype('int32')
            for i in range(bboxes.shape[0]):
                cv2.drawContours(gt_text, [bboxes[i]], 0, i + 1, -1)
                if not tags[i]:
                    cv2.drawContours(training_mask, [bboxes[i]], 0, 0, -1)

        # get gt_kernels
        gt_kernels = []
        for i in range(1, self.kernel_num):
            rate = 1.0 - (1.0 - self.min_scale) / (self.kernel_num - 1) * i
            gt_kernel = np.zeros(img.shape[0:2], dtype=np.float32)
            kernel_bboxes = shrink(bboxes, rate)
            for j in range(kernel_bboxes.shape[0]):
                cv2.drawContours(gt_kernel, [kernel_bboxes[j]], 0, 1, -1)
            gt_kernels.append(gt_kernel)

        # data augmentation
        if self.is_transform:
            imgs = [img, gt_text, training_mask]
            imgs.extend(gt_kernels)
            imgs = random_horizontal_flip(imgs)
            imgs = random_rotate(imgs)
            imgs = random_crop(imgs, self.img_size)
            img, gt_text, training_mask, gt_kernels = imgs[0], imgs[1], imgs[2], imgs[3:]

        gt_text[gt_text > 0] = 1
        gt_kernels = np.array(gt_kernels)

        if self.is_transform:
            img = Image.fromarray(img)
            img = img.convert('RGB')
            img = py_transforms.RandomColorAdjust(brightness=32.0 / 255, saturation=0.5)(img)
        else:
            img = Image.fromarray(img)
            img = img.convert('RGB')

        img = py_transforms.ToTensor()(img)
        img = py_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)

        gt_text = gt_text.astype(np.float32)
        gt_kernels = gt_kernels.astype(np.float32)
        training_mask = training_mask.astype(np.float32)

        return img, gt_text, gt_kernels, training_mask

    def __len__(self):
        return len(self.all_img_paths)


def IC15_TEST_Generator():
    # Yield (original image, square-padded and resized image, file name) for
    # every ICDAR2015 test image.
    ic15_test_data_dir = config.TEST_ROOT_DIR + 'ch4_test_images/'
    img_size = config.INFER_LONG_SIZE

    img_size = img_size if (img_size is None or isinstance(img_size, tuple)) else (img_size, img_size)

    data_dirs = [ic15_test_data_dir]
    all_img_paths = []
    for data_dir in data_dirs:
        img_names = [i for i in os.listdir(data_dir) if os.path.splitext(i)[-1].lower() in ['.jpg', '.jpeg', '.png']]

        img_paths = []
        for _, img_name in enumerate(img_names):
            img_path = data_dir + img_name
            img_paths.append(img_path)

        all_img_paths.extend(img_paths)

    dataset_length = len(all_img_paths)
    for index in range(dataset_length):
        img_path = all_img_paths[index]
        img_name = np.array(os.path.split(img_path)[-1])
        img = get_img(img_path)

        # Pad to a square on the long side, then resize to the inference size.
        long_size = max(img.shape[:2])
        img_resized = np.zeros((long_size, long_size, 3), np.uint8)
        img_resized[:img.shape[0], :img.shape[1], :] = img
        img_resized = cv2.resize(img_resized, dsize=img_size)

        img_resized = Image.fromarray(img_resized)
        img_resized = img_resized.convert('RGB')
        img_resized = py_transforms.ToTensor()(img_resized)
        img_resized = py_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img_resized)

        yield img, img_resized, img_name


class DistributedSampler:
    """Evenly split dataset indices across `group_size` ranks, reshuffling each epoch."""

    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.rank = rank
        self.group_size = group_size
        self.dataset_len = len(self.dataset)
        self.num_samplers = int(math.ceil(self.dataset_len * 1.0 / self.group_size))
        self.total_size = self.num_samplers * self.group_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        if self.shuffle:
            self.seed = (self.seed + 1) & 0xffffffff
            np.random.seed(self.seed)
            indices = np.random.permutation(self.dataset_len).tolist()
        else:
            indices = list(range(self.dataset_len))

        # Pad to a multiple of group_size, then take this rank's strided slice.
        indices += indices[:(self.total_size - len(indices))]
        indices = indices[self.rank::self.group_size]
        return iter(indices)

    def __len__(self):
        return self.num_samplers


def train_dataset_creator(rank, group_size, shuffle=True):
    cv2.setNumThreads(0)
    dataset = TrainDataset()
    sampler = DistributedSampler(dataset, rank, group_size, shuffle)
    ds = de.GeneratorDataset(dataset, ['img', 'gt_text', 'gt_kernels', 'training_mask'],
                             num_parallel_workers=8, sampler=sampler)
    ds = ds.repeat(1)
    ds = ds.batch(config.TRAIN_BATCH_SIZE, drop_remainder=config.TRAIN_DROP_REMAINDER)
    return ds


def test_dataset_creator():
    ds = de.GeneratorDataset(IC15_TEST_Generator, ['img', 'img_resized', 'img_name'])
    ds = ds.shuffle(config.TEST_BUFFER_SIZE)
    ds = ds.batch(1, drop_remainder=config.TEST_DROP_REMAINDER)
    return ds
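
A minimal usage sketch of the two exported creators, assuming this file is importable as src.dataset and that src.config.config already points at the ICDAR2015 directories; the column names come from the GeneratorDataset definitions above, and create_dict_iterator is the standard MindSpore dataset iterator.

# Sketch only, not part of dataset.py; import path src.dataset is an assumption.
from src.dataset import train_dataset_creator, test_dataset_creator

# Training pipeline on a single device: rank 0 in a group of size 1.
train_ds = train_dataset_creator(rank=0, group_size=1, shuffle=True)
for data in train_ds.create_dict_iterator():
    print(data['img'].shape, data['gt_text'].shape,
          data['gt_kernels'].shape, data['training_mask'].shape)
    break  # inspect one batch only

# Inference pipeline: one padded/resized test image per batch, plus its file name.
test_ds = test_dataset_creator()
for data in test_ds.create_dict_iterator():
    print(data['img_name'], data['img_resized'].shape)
    break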