
dataset.py

# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import math
import random

import numpy as np
import cv2
from pycocotools.coco import COCO as ReadJson

import mindspore.dataset as de

from src.config import JointType, params

cv2.setNumThreads(0)

class txtdataset():
    def __init__(self, train, imgpath, maskpath, insize, mode='train', n_samples=None):
        self.train = train
        self.mode = mode
        self.imgpath = imgpath
        self.maskpath = maskpath
        self.insize = insize
        self.maxtime = 0
        self.catIds = train.getCatIds(catNms=['person'])
        self.imgIds = sorted(train.getImgIds(catIds=self.catIds))
        if self.mode == 'train':
            self.clean_imgIds()
        if self.mode in ['val', 'eval'] and n_samples is not None:
            self.imgIds = random.sample(self.imgIds, n_samples)
        print('{} images: {}'.format(mode, len(self)))

    def __len__(self):
        return len(self.imgIds)

    def clean_imgIds(self):
        print("cleaning imgids")
        for img_id in self.imgIds.copy():
            annotations = None
            anno_ids = self.train.getAnnIds(imgIds=[img_id], iscrowd=None)
            # annotation for that image
            if anno_ids:
                annotations_for_img = self.train.loadAnns(anno_ids)
                person_cnt = 0
                valid_annotations_for_img = []
                for annotation in annotations_for_img:
                    # if too few keypoints or too small
                    if annotation['num_keypoints'] >= params['min_keypoints'] and \
                            annotation['area'] > params['min_area']:
                        person_cnt += 1
                        valid_annotations_for_img.append(annotation)
                # if person annotation
                if person_cnt > 0:
                    annotations = valid_annotations_for_img
            if annotations is None:
                self.imgIds.remove(img_id)

    def overlay_paf(self, img, paf):
        hue = ((np.arctan2(paf[1], paf[0]) / np.pi) / -2 + 0.5)
        saturation = np.sqrt(paf[0] ** 2 + paf[1] ** 2)
        saturation[saturation > 1.0] = 1.0
        value = saturation.copy()
        hsv_paf = np.vstack((hue[np.newaxis], saturation[np.newaxis], value[np.newaxis])).transpose(1, 2, 0)
        rgb_paf = cv2.cvtColor((hsv_paf * 255).astype(np.uint8), cv2.COLOR_HSV2BGR)
        img = cv2.addWeighted(img, 0.6, rgb_paf, 0.4, 0)
        return img

    def overlay_pafs(self, img, pafs):
        mix_paf = np.zeros((2,) + img.shape[:-1])
        paf_flags = np.zeros(mix_paf.shape)  # for constant paf
        for paf in pafs.reshape((int(pafs.shape[0] / 2), 2,) + pafs.shape[1:]):
            paf_flags = paf != 0
            paf_flags += np.broadcast_to(paf_flags[0] | paf_flags[1], paf.shape)
            mix_paf += paf
        mix_paf[paf_flags > 0] /= paf_flags[paf_flags > 0]
        img = self.overlay_paf(img, mix_paf)
        return img

    def overlay_heatmap(self, img, heatmap):
        rgb_heatmap = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
        img = cv2.addWeighted(img, 0.6, rgb_heatmap, 0.4, 0)
        return img

    def overlay_ignore_mask(self, img, ignore_mask):
        img = img * np.repeat((ignore_mask == 0).astype(np.uint8)[:, :, None], 3, axis=2)
        return img

    # -------------------- augment code --------------------------------
    def get_pose_bboxes(self, poses):
        pose_bboxes = []
        for pose in poses:
            x1 = pose[pose[:, 2] > 0][:, 0].min()
            y1 = pose[pose[:, 2] > 0][:, 1].min()
            x2 = pose[pose[:, 2] > 0][:, 0].max()
            y2 = pose[pose[:, 2] > 0][:, 1].max()
            pose_bboxes.append([x1, y1, x2, y2])
        pose_bboxes = np.array(pose_bboxes)
        return pose_bboxes

    def resize_data(self, img, ignore_mask, poses, shape):
        """resize img, mask and annotations"""
        img_h, img_w, _ = img.shape
        resized_img = cv2.resize(img, shape)
        ignore_mask = cv2.resize(ignore_mask.astype(np.uint8), shape).astype('bool')
        poses[:, :, :2] = (poses[:, :, :2] * np.array(shape) / np.array((img_w, img_h)))
        return resized_img, ignore_mask, poses

    def random_resize_img(self, img, ignore_mask, poses):
        h, w, _ = img.shape
        joint_bboxes = self.get_pose_bboxes(poses)
        bbox_sizes = ((joint_bboxes[:, 2:] - joint_bboxes[:, :2] + 1) ** 2).sum(axis=1) ** 0.5
        min_scale = params['min_box_size'] / bbox_sizes.min()
        max_scale = params['max_box_size'] / bbox_sizes.max()
        min_scale = min(max(min_scale, params['min_scale']), 1)
        max_scale = min(max(max_scale, 1), params['max_scale'])
        scale = float((max_scale - min_scale) * random.random() + min_scale)
        shape = (round(w * scale), round(h * scale))
        resized_img, resized_mask, resized_poses = self.resize_data(img, ignore_mask, poses, shape)
        return resized_img, resized_mask, resized_poses

    def random_rotate_img(self, img, mask, poses):
        h, w, _ = img.shape
        degree = np.random.randn() / 3 * params['max_rotate_degree']
        rad = degree * math.pi / 180
        center = (w / 2, h / 2)
        R = cv2.getRotationMatrix2D(center, degree, 1)
        bbox = (w * abs(math.cos(rad)) + h * abs(math.sin(rad)), w * abs(math.sin(rad)) + h * abs(math.cos(rad)))
        R[0, 2] += bbox[0] / 2 - center[0]
        R[1, 2] += bbox[1] / 2 - center[1]
        rotate_img = cv2.warpAffine(img, R, (int(bbox[0] + 0.5), int(bbox[1] + 0.5)), flags=cv2.INTER_CUBIC,
                                    borderMode=cv2.BORDER_CONSTANT, borderValue=[127.5, 127.5, 127.5])
        rotate_mask = cv2.warpAffine(mask.astype('uint8') * 255, R, (int(bbox[0] + 0.5), int(bbox[1] + 0.5))) > 0
        tmp_poses = np.ones_like(poses)
        tmp_poses[:, :, :2] = poses[:, :, :2].copy()
        tmp_rotate_poses = np.dot(tmp_poses, R.T)  # apply rotation matrix to the poses
        rotate_poses = poses.copy()  # to keep visibility flag
        rotate_poses[:, :, :2] = tmp_rotate_poses
        return rotate_img, rotate_mask, rotate_poses
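
    # Note on random_rotate_img: the output canvas is enlarged to the rotated bounding box
    # (w*|cos| + h*|sin|, w*|sin| + h*|cos|) and the rotation matrix is shifted accordingly,
    # so the rotated image is never clipped; keypoints are transformed with the same affine
    # matrix, and the visibility flags are carried over unchanged.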

    def random_crop_img(self, img, ignore_mask, poses):
        h, w, _ = img.shape
        insize = self.insize
        joint_bboxes = self.get_pose_bboxes(poses)
        bbox = random.choice(joint_bboxes)  # select a bbox randomly
        bbox_center = bbox[:2] + (bbox[2:] - bbox[:2]) / 2
        r_xy = np.random.rand(2)
        perturb = ((r_xy - 0.5) * 2 * params['center_perterb_max'])
        center = (bbox_center + perturb + 0.5).astype('i')
        crop_img = np.zeros((insize, insize, 3), 'uint8') + 127.5
        crop_mask = np.zeros((insize, insize), 'bool')
        offset = (center - (insize - 1) / 2 + 0.5).astype('i')
        offset_ = (center + (insize - 1) / 2 - (w - 1, h - 1) + 0.5).astype('i')
        x1, y1 = (center - (insize - 1) / 2 + 0.5).astype('i')
        x2, y2 = (center + (insize - 1) / 2 + 0.5).astype('i')
        x1 = max(x1, 0)
        y1 = max(y1, 0)
        x2 = min(x2, w - 1)
        y2 = min(y2, h - 1)
        x_from = -offset[0] if offset[0] < 0 else 0
        y_from = -offset[1] if offset[1] < 0 else 0
        x_to = insize - offset_[0] - 1 if offset_[0] >= 0 else insize - 1
        y_to = insize - offset_[1] - 1 if offset_[1] >= 0 else insize - 1
        crop_img[y_from:y_to + 1, x_from:x_to + 1] = img[y1:y2 + 1, x1:x2 + 1].copy()
        crop_mask[y_from:y_to + 1, x_from:x_to + 1] = ignore_mask[y1:y2 + 1, x1:x2 + 1].copy()
        poses[:, :, :2] -= offset
        return crop_img.astype('uint8'), crop_mask, poses

    def distort_color(self, img):
        img_max = np.broadcast_to(np.array(255, dtype=np.uint8), img.shape[:-1])
        img_min = np.zeros(img.shape[:-1], dtype=np.uint8)
        hsv_img = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2HSV).astype(np.int32)
        hsv_img[:, :, 0] = np.maximum(np.minimum(hsv_img[:, :, 0] - 10 + np.random.randint(20 + 1), img_max), img_min)  # hue
        hsv_img[:, :, 1] = np.maximum(np.minimum(hsv_img[:, :, 1] - 40 + np.random.randint(80 + 1), img_max), img_min)  # saturation
        hsv_img[:, :, 2] = np.maximum(np.minimum(hsv_img[:, :, 2] - 30 + np.random.randint(60 + 1), img_max), img_min)  # value
        hsv_img = hsv_img.astype(np.uint8)
        distorted_img = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
        return distorted_img

    def flip_img(self, img, mask, poses):
        flipped_img = cv2.flip(img, 1)
        flipped_mask = cv2.flip(mask.astype(np.uint8), 1).astype('bool')
        poses[:, :, 0] = img.shape[1] - 1 - poses[:, :, 0]

        def swap_joints(poses, joint_type_, joint_type_2):
            tmp = poses[:, joint_type_].copy()
            poses[:, joint_type_] = poses[:, joint_type_2]
            poses[:, joint_type_2] = tmp

        swap_joints(poses, JointType.LeftEye, JointType.RightEye)
        swap_joints(poses, JointType.LeftEar, JointType.RightEar)
        swap_joints(poses, JointType.LeftShoulder, JointType.RightShoulder)
        swap_joints(poses, JointType.LeftElbow, JointType.RightElbow)
        swap_joints(poses, JointType.LeftHand, JointType.RightHand)
        swap_joints(poses, JointType.LeftWaist, JointType.RightWaist)
        swap_joints(poses, JointType.LeftKnee, JointType.RightKnee)
        swap_joints(poses, JointType.LeftFoot, JointType.RightFoot)
        return flipped_img, flipped_mask, poses

    def augment_data(self, img, ignore_mask, poses):
        aug_img = img.copy()
        aug_img, ignore_mask, poses = self.random_resize_img(aug_img, ignore_mask, poses)
        aug_img, ignore_mask, poses = self.random_rotate_img(aug_img, ignore_mask, poses)
        aug_img, ignore_mask, poses = self.random_crop_img(aug_img, ignore_mask, poses)
        if np.random.randint(2):
            aug_img = self.distort_color(aug_img)
        if np.random.randint(2):
            aug_img, ignore_mask, poses = self.flip_img(aug_img, ignore_mask, poses)
        return aug_img, ignore_mask, poses
    # ------------------------------- end -----------------------------------

    # ------------------------------ Heatmap ------------------------------------
    # return shape: (height, width)
    def generate_gaussian_heatmap(self, shape, joint, sigma):
        x, y = joint
        grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
        grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
        grid_distance = (grid_x - x) ** 2 + (grid_y - y) ** 2
        gaussian_heatmap = np.exp(-0.5 * grid_distance / sigma ** 2)
        return gaussian_heatmap

    def generate_heatmaps(self, img, poses, heatmap_sigma):
        heatmaps = np.zeros((0,) + img.shape[:-1])
        sum_heatmap = np.zeros(img.shape[:-1])
        for joint_index in range(len(JointType)):
            heatmap = np.zeros(img.shape[:-1])
            for pose in poses:
                if pose[joint_index, 2] > 0:
                    jointmap = self.generate_gaussian_heatmap(img.shape[:-1], pose[joint_index][:2], heatmap_sigma)
                    heatmap[jointmap > heatmap] = jointmap[jointmap > heatmap]
                    sum_heatmap[jointmap > sum_heatmap] = jointmap[jointmap > sum_heatmap]
            heatmaps = np.vstack((heatmaps, heatmap.reshape((1,) + heatmap.shape)))
        bg_heatmap = 1 - sum_heatmap  # background channel
        heatmaps = np.vstack((heatmaps, bg_heatmap[None]))
        return heatmaps.astype('f')
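
    # Note on generate_heatmaps: each joint channel keeps, per pixel, the maximum Gaussian
    # response over all annotated people for that joint, and a final background channel is
    # added as 1 minus the pixel-wise maximum over every joint, so the returned stack has
    # len(JointType) + 1 channels at full image resolution.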

    def generate_gaussian_heatmap_fast(self, shape, joint, sigma):
        x, y = joint
        grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
        grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
        grid_x = grid_x + 0.4375
        grid_y = grid_y + 0.4375
        grid_distance = (grid_x - x) ** 2 + (grid_y - y) ** 2
        gaussian_heatmap = np.exp(-0.5 * grid_distance / sigma ** 2)
        return gaussian_heatmap

    def generate_heatmaps_fast(self, img, poses, heatmap_sigma):
        resize_shape = (img.shape[0] // 8, img.shape[1] // 8)
        heatmaps = np.zeros((0,) + resize_shape)
        sum_heatmap = np.zeros(resize_shape)
        for joint_index in range(len(JointType)):
            heatmap = np.zeros(resize_shape)
            for pose in poses:
                if pose[joint_index, 2] > 0:
                    jointmap = self.generate_gaussian_heatmap_fast(resize_shape, pose[joint_index][:2] / 8,
                                                                   heatmap_sigma / 8)
                    index_1 = jointmap > heatmap
                    heatmap[index_1] = jointmap[index_1]
                    index_2 = jointmap > sum_heatmap
                    sum_heatmap[index_2] = jointmap[index_2]
            heatmaps = np.vstack((heatmaps, heatmap.reshape((1,) + heatmap.shape)))
        bg_heatmap = 1 - sum_heatmap  # background channel
        heatmaps = np.vstack((heatmaps, bg_heatmap[None]))
        return heatmaps.astype('f')
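
    # Note on the *_fast variants: they build the label maps directly at 1/8 of the input
    # resolution (the network output stride), scaling joint coordinates and sigma by 8
    # instead of rendering at full size and resizing afterwards. The 0.4375 offset appears
    # to align each output cell with the centre of the 8x8 input block it covers: the block
    # centre 8*i + 3.5 maps to i + 3.5/8 = i + 0.4375 in the downsampled coordinate frame.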

    # ------------------------------ end ------------------------------------
    # ------------------------------ PAF ------------------------------------
    # return shape: (2, height, width)
    def generate_constant_paf(self, shape, joint_from, joint_to, paf_width):
        if np.array_equal(joint_from, joint_to):  # same joint
            return np.zeros((2,) + shape[:-1])
        joint_distance = np.linalg.norm(joint_to - joint_from)
        unit_vector = (joint_to - joint_from) / joint_distance
        rad = np.pi / 2
        # [[0, 1], [-1, 0]]
        rot_matrix = np.array([[np.cos(rad), np.sin(rad)], [-np.sin(rad), np.cos(rad)]])
        # [[u_y], [-u_x]]
        vertical_unit_vector = np.dot(rot_matrix, unit_vector)
        grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
        grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
        horizontal_inner_product = unit_vector[0] * (grid_x - joint_from[0]) + unit_vector[1] * (grid_y - joint_from[1])
        horizontal_paf_flag = (horizontal_inner_product >= 0) & (horizontal_inner_product <= joint_distance)
        vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + \
            vertical_unit_vector[1] * (grid_y - joint_from[1])
        vertical_paf_flag = np.abs(vertical_inner_product) <= paf_width  # paf_width : 8
        paf_flag = horizontal_paf_flag & vertical_paf_flag
        constant_paf = np.stack((paf_flag, paf_flag)) * \
            np.broadcast_to(unit_vector, shape[:-1] + (2,)).transpose(2, 0, 1)
        return constant_paf

    def generate_pafs(self, img, poses, paf_sigma):
        pafs = np.zeros((0,) + img.shape[:-1])
        for limb in params['limbs_point']:
            paf = np.zeros((2,) + img.shape[:-1])
            paf_flags = np.zeros(paf.shape)  # for constant paf
            for pose in poses:
                joint_from, joint_to = pose[limb]
                if joint_from[2] > 0 and joint_to[2] > 0:
                    limb_paf = self.generate_constant_paf(img.shape, joint_from[:2], joint_to[:2], paf_sigma)  # [2, 368, 368]
                    limb_paf_flags = limb_paf != 0
                    paf_flags += np.broadcast_to(limb_paf_flags[0] | limb_paf_flags[1], limb_paf.shape)
                    paf += limb_paf
            paf[paf_flags > 0] /= paf_flags[paf_flags > 0]
            pafs = np.vstack((pafs, paf))
        return pafs.astype('f')
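
    # Note on generate_constant_paf: pixels whose projection onto the limb direction falls
    # between the two joints, and whose perpendicular distance from the limb axis is at most
    # paf_width, receive the limb's unit vector in both channels; all other pixels stay zero.
    # generate_pafs accumulates these per limb and divides by paf_flags so that pixels
    # covered by several people's limbs hold the average direction.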

    def generate_constant_paf_fast(self, shape, joint_from, joint_to, paf_width):
        if np.array_equal(joint_from, joint_to):  # same joint
            return np.zeros((2,) + shape[:-1])
        joint_distance = np.linalg.norm(joint_to - joint_from)
        unit_vector = (joint_to - joint_from) / joint_distance
        rad = np.pi / 2
        # [[0, 1], [-1, 0]]
        rot_matrix = np.array([[np.cos(rad), np.sin(rad)], [-np.sin(rad), np.cos(rad)]])
        # [[u_y], [-u_x]]
        vertical_unit_vector = np.dot(rot_matrix, unit_vector)
        grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
        grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
        grid_x = grid_x + 0.4375
        grid_y = grid_y + 0.4375
        horizontal_inner_product = unit_vector[0] * (grid_x - joint_from[0]) + unit_vector[1] * (grid_y - joint_from[1])
        horizontal_paf_flag = (horizontal_inner_product >= 0) & (horizontal_inner_product <= joint_distance)
        vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + \
            vertical_unit_vector[1] * (grid_y - joint_from[1])
        vertical_paf_flag = np.abs(vertical_inner_product) <= paf_width  # paf_width : 8/8 = 1
        paf_flag = horizontal_paf_flag & vertical_paf_flag
        constant_paf = np.stack((paf_flag, paf_flag)) * \
            np.broadcast_to(unit_vector, shape[:-1] + (2,)).transpose(2, 0, 1)
        return constant_paf

    def generate_pafs_fast(self, img, poses, paf_sigma):
        resize_shape = (img.shape[0] // 8, img.shape[1] // 8, 3)
        pafs = np.zeros((0,) + resize_shape[:-1])
        for limb in params['limbs_point']:
            paf = np.zeros((2,) + resize_shape[:-1])
            paf_flags = np.zeros(paf.shape)  # for constant paf
            for pose in poses:
                joint_from, joint_to = pose[limb]
                if joint_from[2] > 0 and joint_to[2] > 0:
                    limb_paf = self.generate_constant_paf_fast(resize_shape, joint_from[:2] / 8,
                                                               joint_to[:2] / 8, paf_sigma / 8)  # [2, h//8, w//8]
                    limb_paf_flags = limb_paf != 0
                    paf_flags += np.broadcast_to(limb_paf_flags[0] | limb_paf_flags[1], limb_paf.shape)
                    paf += limb_paf
            index_1 = paf_flags > 0
            paf[index_1] /= paf_flags[index_1]
            pafs = np.vstack((pafs, paf))
        return pafs.astype('f')
    # ------------------------------ end ------------------------------------

    def get_img_annotation(self, ind=None, img_id=None):
        annotations = None
        if ind is not None:
            img_id = self.imgIds[ind]
        anno_ids = self.train.getAnnIds(imgIds=[img_id], iscrowd=None)
        # annotation for that image
        if anno_ids:
            annotations_for_img = self.train.loadAnns(anno_ids)
            person_cnt = 0
            valid_annotations_for_img = []
            for annotation in annotations_for_img:
                # if too few keypoints or too small
                if annotation['num_keypoints'] >= params['min_keypoints'] and annotation['area'] > params['min_area']:
                    person_cnt += 1
                    valid_annotations_for_img.append(annotation)
            # if person annotation
            if person_cnt > 0:
                annotations = valid_annotations_for_img
        img_path = os.path.join(self.imgpath, self.train.loadImgs([img_id])[0]['file_name'])
        mask_path = os.path.join(self.maskpath, '{:012d}.png'.format(img_id))
        img = cv2.imread(img_path)
        ignore_mask = cv2.imread(mask_path, 0)
        if ignore_mask is None:
            ignore_mask = np.zeros(img.shape[:2], np.float32)
        else:
            ignore_mask[ignore_mask == 255] = 1
        if self.mode == 'eval':
            return img, img_id, annotations_for_img, ignore_mask
        return img, img_id, annotations, ignore_mask.astype('f')

    def parse_annotation(self, annotations):
        poses = np.zeros((0, len(JointType), 3), dtype=np.int32)
        for ann in annotations:
            ann_pose = np.array(ann['keypoints']).reshape(-1, 3)
            pose = np.zeros((1, len(JointType), 3), dtype=np.int32)
            # convert poses position
            for i, joint_index in enumerate(params['joint_indices']):
                pose[0][joint_index] = ann_pose[i]
            # compute neck position
            if pose[0][JointType.LeftShoulder][2] > 0 and pose[0][JointType.RightShoulder][2] > 0:
                pose[0][JointType.Neck][0] = int((pose[0][JointType.LeftShoulder][0] +
                                                  pose[0][JointType.RightShoulder][0]) / 2)
                pose[0][JointType.Neck][1] = int((pose[0][JointType.LeftShoulder][1] +
                                                  pose[0][JointType.RightShoulder][1]) / 2)
                pose[0][JointType.Neck][2] = 2
            poses = np.vstack((poses, pose))
        return poses

    def resize_output(self, input_np, map_h=46, map_w=46):
        if len(input_np.shape) == 3:
            output = np.zeros((input_np.shape[0], map_h, map_w))
            for i in range(input_np.shape[0]):
                output[i] = cv2.resize(input_np[i], (map_w, map_h))
            return output.astype('f')
        input_np = input_np.astype('f')
        output = cv2.resize(input_np, (map_w, map_h))  # cv2.resize expects (width, height)
        return output

    def generate_labels(self, img, poses, ignore_mask):
        img, ignore_mask, poses = self.augment_data(img, ignore_mask, poses)
        resized_img, ignore_mask, resized_poses = self.resize_data(img, ignore_mask, poses,
                                                                   shape=(self.insize, self.insize))
        resized_heatmaps = self.generate_heatmaps_fast(resized_img, resized_poses, params['heatmap_sigma'])
        resized_pafs = self.generate_pafs_fast(resized_img, resized_poses, params['paf_sigma'])
        ignore_mask = cv2.morphologyEx(ignore_mask.astype('uint8'), cv2.MORPH_DILATE, np.ones((16, 16))).astype('bool')
        resized_ignore_mask = self.resize_output(ignore_mask)
        return resized_img, resized_pafs, resized_heatmaps, resized_ignore_mask

    def preprocess(self, img):
        x_data = img.astype('f')
        x_data /= 255
        x_data -= 0.5
        x_data = x_data.transpose(2, 0, 1)
        return x_data
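
    # Note on preprocess: pixel values are scaled from [0, 255] to [0, 1], shifted to
    # [-0.5, 0.5], and the image is transposed from HWC to CHW (channel-first) layout.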

    def __getitem__(self, i):
        img, img_id, annotations, ignore_mask = self.get_img_annotation(ind=i)
        if self.mode in ['eval', 'val']:
            # don't need to make heatmaps/pafs
            return img, np.array([img_id])
        # if no annotations are available
        while annotations is None:
            print("none annotations", img_id)
            img_id = self.imgIds[np.random.randint(len(self))]
            img, img_id, annotations, ignore_mask = self.get_img_annotation(img_id=img_id)
        poses = self.parse_annotation(annotations)
        # TEST
        # return img, poses, ignore_mask
        resized_img, pafs, heatmaps, ignore_mask = self.generate_labels(img, poses, ignore_mask)
        resized_img = self.preprocess(resized_img)
        ignore_mask = 1. - ignore_mask
        # # TEST
        # print("Shape: ", resized_img.dtype, " ", pafs.dtype, " ", heatmaps.dtype, " ", ignore_mask.dtype)
        return resized_img, pafs, heatmaps, ignore_mask


class DistributedSampler():
    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.rank = rank
        self.group_size = group_size
        self.dataset_len = len(self.dataset)
        self.num_samplers = int(math.ceil(self.dataset_len * 1.0 / self.group_size))
        self.total_size = self.num_samplers * self.group_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        if self.shuffle:
            self.seed = (self.seed + 1) & 0xffffffff
            np.random.seed(self.seed)
            indices = np.random.permutation(self.dataset_len).tolist()
        else:
            indices = list(range(self.dataset_len))
        indices += indices[:(self.total_size - len(indices))]
        indices = indices[self.rank::self.group_size]
        return iter(indices)

    def __len__(self):
        return self.num_samplers
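
# Sharding example for DistributedSampler (illustrative values only): with dataset_len=10
# and group_size=4, num_samplers = ceil(10/4) = 3 and total_size = 12, so the (possibly
# shuffled) index list is padded with its own first two entries to length 12 and each rank
# r then takes indices[r::4], i.e. 3 samples per rank per epoch.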


def valdata(jsonpath, imgpath, rank, group_size, mode='val', maskpath=''):
    # cv2.setNumThreads(0)
    val = ReadJson(jsonpath)
    dataset = txtdataset(val, imgpath, maskpath, params['insize'], mode=mode)
    sampler = DistributedSampler(dataset, rank, group_size)
    ds = de.GeneratorDataset(dataset, ['img', 'img_id'], num_parallel_workers=8, sampler=sampler)
    ds = ds.repeat(1)
    return ds


def create_dataset(jsonpath, imgpath, maskpath, batch_size, rank, group_size, mode='train', repeat_num=1, shuffle=True,
                   multiprocessing=True, num_worker=20):
    train = ReadJson(jsonpath)
    dataset = txtdataset(train, imgpath, maskpath, params['insize'], mode=mode)
    if group_size == 1:
        de_dataset = de.GeneratorDataset(dataset, ["image", "pafs", "heatmaps", "ignore_mask"],
                                         shuffle=shuffle,
                                         num_parallel_workers=num_worker,
                                         python_multiprocessing=multiprocessing)
    else:
        de_dataset = de.GeneratorDataset(dataset, ["image", "pafs", "heatmaps", "ignore_mask"],
                                         shuffle=shuffle,
                                         num_parallel_workers=num_worker,
                                         python_multiprocessing=multiprocessing,
                                         num_shards=group_size,
                                         shard_id=rank)
    de_dataset = de_dataset.batch(batch_size=batch_size, drop_remainder=True)
    de_dataset = de_dataset.repeat(repeat_num)
    return de_dataset
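

# A minimal usage sketch, assuming COCO-style keypoint annotations and pre-generated ignore
# masks on local disk; the paths below are placeholders, not part of this repository.
if __name__ == '__main__':
    ANNOTATION_JSON = '/path/to/annotations/person_keypoints_train2017.json'  # placeholder
    IMAGE_DIR = '/path/to/train2017'  # placeholder
    MASK_DIR = '/path/to/ignore_mask_train'  # placeholder
    train_ds = create_dataset(ANNOTATION_JSON, IMAGE_DIR, MASK_DIR,
                              batch_size=10, rank=0, group_size=1)
    print('batches per epoch:', train_ds.get_dataset_size())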